tests: improve RAGFlow coverage based on Codecov report (#13219)

### What problem does this PR solve?

Codecov’s coverage report shows that several RAGFlow code paths are
currently untested or under-tested. This makes it easier for regressions
to slip in during refactors and feature work.
This PR adds targeted automated tests to cover the files and branches
highlighted by Codecov, improving confidence in core behavior while
keeping runtime functionality unchanged.

### Type of change

- [x] Other (please describe): Test coverage improvement (adds/extends
unit and integration tests to address Codecov-reported gaps)
This commit is contained in:
6ba3i
2026-02-26 19:03:26 +08:00
committed by GitHub
parent 1aa49a11f0
commit 22c4d72891
26 changed files with 11107 additions and 13 deletions

View File

@ -125,7 +125,8 @@ class TestDatasetUpdate:
@pytest.mark.p1
@given(name=valid_names())
@example("a" * 128)
@settings(max_examples=20, suppress_health_check=[HealthCheck.function_scoped_fixture])
# Network-bound API call; disable Hypothesis deadline to avoid flaky timeouts.
@settings(max_examples=20, suppress_health_check=[HealthCheck.function_scoped_fixture], deadline=None)
def test_name(self, HttpApiAuth, add_dataset_func, name):
dataset_id = add_dataset_func
payload = {"name": name}

View File

@ -77,6 +77,15 @@ class _StubResponse:
self.headers = _StubHeaders()
class _DummyUploadFile:
def __init__(self, filename):
self.filename = filename
self.saved_path = None
async def save(self, path):
self.saved_path = path
def _run(coro):
return asyncio.run(coro)
@ -96,6 +105,16 @@ async def _collect_stream(body):
return items
@pytest.fixture(scope="session")
def auth():
    """Placeholder auth token; presumably overrides a heavier shared fixture -- confirm in conftest."""
    return "unit-auth"


@pytest.fixture(scope="session", autouse=True)
def set_tenant_info():
    """Autouse no-op standing in for tenant setup these unit tests do not need."""
    return None
def _load_session_module(monkeypatch):
repo_root = Path(__file__).resolve().parents[4]
common_pkg = ModuleType("common")
@ -1125,3 +1144,355 @@ def test_searchbots_retrieval_test_embedded_matrix_unit(monkeypatch):
assert retrieval_capture["rank_feature"] == ["label-1"]
assert retrieval_capture["rerank_mdl"] is not None
assert any(call[1] == module.LLMType.EMBEDDING.value and call[3].get("llm_name") == "embd-model" for call in llm_calls)
llm_calls.clear()
async def _fake_keyword_extraction(_chat_mdl, question):
return f"-{question}-keywords"
async def _fake_kg_retrieval(question, tenant_ids, kb_ids, _embd_mdl, _chat_mdl):
return {
"id": "kg-chunk",
"question": question,
"tenant_ids": tenant_ids,
"kb_ids": kb_ids,
"content_with_weight": 1,
"vector": [0.5],
}
monkeypatch.setattr(module, "keyword_extraction", _fake_keyword_extraction)
monkeypatch.setattr(module.settings, "kg_retriever", SimpleNamespace(retrieval=_fake_kg_retrieval))
monkeypatch.setattr(
module,
"get_request_json",
lambda: _AwaitableValue(
{
"kb_id": "kb-1",
"question": "keyword-q",
"rerank_id": "manual-reranker",
"keyword": True,
"use_kg": True,
}
),
)
monkeypatch.setattr(
module.KnowledgebaseService,
"get_by_id",
lambda _kb_id: (True, SimpleNamespace(tenant_id="tenant-kb", embd_id="embd-model")),
)
res = _run(handler())
assert res["code"] == 0
assert res["data"]["chunks"][0]["id"] == "kg-chunk"
assert all("vector" not in chunk for chunk in res["data"]["chunks"])
assert any(call[1] == module.LLMType.RERANK.value for call in llm_calls)
async def _raise_not_found(*_args, **_kwargs):
raise RuntimeError("x not_found y")
monkeypatch.setattr(module.settings, "retriever", SimpleNamespace(retrieval=_raise_not_found))
monkeypatch.setattr(
module,
"get_request_json",
lambda: _AwaitableValue({"kb_id": "kb-1", "question": "q"}),
)
res = _run(handler())
assert res["message"] == "No chunk found! Check the chunk status please!"
@pytest.mark.p2
def test_searchbots_related_questions_embedded_matrix_unit(monkeypatch):
    """Walk the auth-failure branches, then the happy path, of related_questions_embedded."""
    module = _load_session_module(monkeypatch)
    handler = inspect.unwrap(module.related_questions_embedded)

    # Bare "Bearer" header: token missing entirely.
    monkeypatch.setattr(module, "request", SimpleNamespace(headers={"Authorization": "Bearer"}))
    assert _run(handler())["message"] == "Authorization is not valid!"

    # Token present but unknown to APIToken.
    monkeypatch.setattr(module, "request", SimpleNamespace(headers={"Authorization": "Bearer bad"}))
    monkeypatch.setattr(module.APIToken, "query", lambda **_kwargs: [])
    assert "API key is invalid" in _run(handler())["message"]

    # Valid token but empty tenant id -> permission failure.
    monkeypatch.setattr(module, "request", SimpleNamespace(headers={"Authorization": "Bearer ok"}))
    monkeypatch.setattr(module.APIToken, "query", lambda **_kwargs: [SimpleNamespace(tenant_id="")])
    monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"question": "q"}))
    assert _run(handler())["message"] == "permission denined."

    seen = {}

    class _FakeChatBundle:
        async def async_chat(self, prompt, messages, options):
            seen["prompt"] = prompt
            seen["messages"] = messages
            seen["options"] = options
            # The unnumbered third line is expected to be dropped by the parser.
            return "1. Alpha\n2. Beta\nignored"

    def _fake_bundle(*args, **_kwargs):
        seen["bundle_args"] = args
        return _FakeChatBundle()

    monkeypatch.setattr(module.APIToken, "query", lambda **_kwargs: [SimpleNamespace(tenant_id="tenant-1")])
    monkeypatch.setattr(
        module,
        "get_request_json",
        lambda: _AwaitableValue({"question": "solar", "search_id": "search-1"}),
    )
    monkeypatch.setattr(
        module.SearchService,
        "get_detail",
        lambda _search_id: {"search_config": {"chat_id": "chat-x", "llm_setting": {"temperature": 0.2}}},
    )
    monkeypatch.setattr(module, "LLMBundle", _fake_bundle)

    res = _run(handler())
    assert res["code"] == 0
    assert res["data"] == ["Alpha", "Beta"]
    assert seen["bundle_args"] == ("tenant-1", module.LLMType.CHAT, "chat-x")
    assert seen["options"] == {"temperature": 0.2}
    assert "Keywords: solar" in seen["messages"][0]["content"]
@pytest.mark.p2
def test_searchbots_detail_share_embedded_matrix_unit(monkeypatch):
    """Cover auth failures, ownership checks, and detail lookup of detail_share_embedded."""
    module = _load_session_module(monkeypatch)
    handler = inspect.unwrap(module.detail_share_embedded)

    def _with_auth(token):
        monkeypatch.setattr(
            module,
            "request",
            SimpleNamespace(headers={"Authorization": token}, args={"search_id": "s-1"}),
        )

    _with_auth("Bearer")
    assert _run(handler())["message"] == "Authorization is not valid!"

    _with_auth("Bearer bad")
    monkeypatch.setattr(module.APIToken, "query", lambda **_kwargs: [])
    assert "API key is invalid" in _run(handler())["message"]

    _with_auth("Bearer ok")
    monkeypatch.setattr(module.APIToken, "query", lambda **_kwargs: [SimpleNamespace(tenant_id="")])
    assert _run(handler())["message"] == "permission denined."

    # Authenticated tenant that does not own the search app.
    monkeypatch.setattr(module.APIToken, "query", lambda **_kwargs: [SimpleNamespace(tenant_id="tenant-1")])
    monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [SimpleNamespace(tenant_id="tenant-a")])
    monkeypatch.setattr(module.SearchService, "query", lambda **_kwargs: [])
    res = _run(handler())
    assert res["code"] == module.RetCode.OPERATING_ERROR
    assert "Has no permission for this operation." in res["message"]

    # Owned, but the detail record is missing.
    monkeypatch.setattr(module.SearchService, "query", lambda **_kwargs: [SimpleNamespace(id="s-1")])
    monkeypatch.setattr(module.SearchService, "get_detail", lambda _sid: None)
    assert _run(handler())["message"] == "Can't find this Search App!"

    # Full success path.
    monkeypatch.setattr(module.SearchService, "get_detail", lambda _sid: {"id": "s-1", "name": "search-app"})
    res = _run(handler())
    assert res["code"] == 0
    assert res["data"]["id"] == "s-1"
@pytest.mark.p2
def test_searchbots_mindmap_embedded_matrix_unit(monkeypatch):
    """Exercise mindmap auth failures, default vs search-id configs, and the error result."""
    module = _load_session_module(monkeypatch)
    handler = inspect.unwrap(module.mindmap)

    monkeypatch.setattr(module, "request", SimpleNamespace(headers={"Authorization": "Bearer"}))
    assert _run(handler())["message"] == "Authorization is not valid!"

    monkeypatch.setattr(module, "request", SimpleNamespace(headers={"Authorization": "Bearer bad"}))
    monkeypatch.setattr(module.APIToken, "query", lambda **_kwargs: [])
    assert "API key is invalid" in _run(handler())["message"]

    monkeypatch.setattr(module, "request", SimpleNamespace(headers={"Authorization": "Bearer ok"}))
    monkeypatch.setattr(module.APIToken, "query", lambda **_kwargs: [SimpleNamespace(tenant_id="tenant-1")])
    monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"question": "q", "kb_ids": ["kb-1"]}))

    seen = {}

    async def _gen_ok(question, kb_ids, tenant_id, search_config):
        seen["params"] = (question, kb_ids, tenant_id, search_config)
        return {"nodes": [question]}

    monkeypatch.setattr(module, "gen_mindmap", _gen_ok)
    res = _run(handler())
    assert res["code"] == 0
    assert res["data"] == {"nodes": ["q"]}
    # Without search_id the handler forwards an empty search config.
    assert seen["params"] == ("q", ["kb-1"], "tenant-1", {})

    monkeypatch.setattr(
        module,
        "get_request_json",
        lambda: _AwaitableValue({"question": "q2", "kb_ids": ["kb-1"], "search_id": "search-1"}),
    )
    monkeypatch.setattr(module.SearchService, "get_detail", lambda _sid: {"search_config": {"mode": "graph"}})
    res = _run(handler())
    assert res["code"] == 0
    assert seen["params"] == ("q2", ["kb-1"], "tenant-1", {"mode": "graph"})

    async def _gen_error(*_args, **_kwargs):
        return {"error": "mindmap boom"}

    monkeypatch.setattr(module, "gen_mindmap", _gen_error)
    assert "mindmap boom" in _run(handler())["message"]
@pytest.mark.p2
def test_sequence2txt_embedded_validation_and_stream_matrix_unit(monkeypatch):
    """Validation errors, sync transcription, SSE streaming, and stream-error handling."""
    module = _load_session_module(monkeypatch)
    handler = inspect.unwrap(module.sequence2txt)
    monkeypatch.setattr(module, "Response", _StubResponse)
    # Keep the temp audio file entirely virtual.
    monkeypatch.setattr(module.tempfile, "mkstemp", lambda suffix: (11, f"/tmp/audio{suffix}"))
    monkeypatch.setattr(module.os, "close", lambda _fd: None)

    def _set_request(form, files):
        monkeypatch.setattr(
            module,
            "request",
            SimpleNamespace(form=_AwaitableValue(form), files=_AwaitableValue(files)),
        )

    _set_request({"stream": "false"}, {})
    assert "Missing 'file' in multipart form-data" in _run(handler("tenant-1"))["message"]

    _set_request({"stream": "false"}, {"file": _DummyUploadFile("bad.txt")})
    assert "Unsupported audio format: .txt" in _run(handler("tenant-1"))["message"]

    _set_request({"stream": "false"}, {"file": _DummyUploadFile("audio.wav")})
    monkeypatch.setattr(module.TenantService, "get_info_by", lambda _tid: [])
    assert _run(handler("tenant-1"))["message"] == "Tenant not found!"

    _set_request({"stream": "false"}, {"file": _DummyUploadFile("audio.wav")})
    monkeypatch.setattr(module.TenantService, "get_info_by", lambda _tid: [{"tenant_id": "tenant-1", "asr_id": ""}])
    assert _run(handler("tenant-1"))["message"] == "No default ASR model is set"

    class _SyncASR:
        def transcription(self, _path):
            return "transcribed text"

        def stream_transcription(self, _path):
            return []

    _set_request({"stream": "false"}, {"file": _DummyUploadFile("audio.wav")})
    monkeypatch.setattr(module.TenantService, "get_info_by", lambda _tid: [{"tenant_id": "tenant-1", "asr_id": "asr-x"}])
    monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: _SyncASR())
    # A failing temp-file cleanup must not break the successful response.
    monkeypatch.setattr(module.os, "remove", lambda _path: (_ for _ in ()).throw(RuntimeError("cleanup fail")))
    res = _run(handler("tenant-1"))
    assert res["code"] == 0
    assert res["data"]["text"] == "transcribed text"

    class _StreamASR:
        def transcription(self, _path):
            return ""

        def stream_transcription(self, _path):
            yield {"event": "partial", "text": "hello"}

    _set_request({"stream": "true"}, {"file": _DummyUploadFile("audio.wav")})
    monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: _StreamASR())
    monkeypatch.setattr(module.os, "remove", lambda _path: None)
    resp = _run(handler("tenant-1"))
    assert isinstance(resp, _StubResponse)
    assert resp.content_type == "text/event-stream"
    chunks = _run(_collect_stream(resp.body))
    assert any('"event": "partial"' in chunk for chunk in chunks)

    class _ErrorASR:
        def transcription(self, _path):
            return ""

        def stream_transcription(self, _path):
            raise RuntimeError("stream asr boom")

    # Generator errors are surfaced inside the SSE stream, even when cleanup also fails.
    _set_request({"stream": "true"}, {"file": _DummyUploadFile("audio.wav")})
    monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: _ErrorASR())
    monkeypatch.setattr(module.os, "remove", lambda _path: (_ for _ in ()).throw(RuntimeError("cleanup boom")))
    resp = _run(handler("tenant-1"))
    chunks = _run(_collect_stream(resp.body))
    assert any("stream asr boom" in chunk for chunk in chunks)
@pytest.mark.p2
def test_tts_embedded_stream_and_error_matrix_unit(monkeypatch):
    """Tenant/model validation, happy-path audio streaming, and mid-stream TTS errors."""
    module = _load_session_module(monkeypatch)
    handler = inspect.unwrap(module.tts)
    monkeypatch.setattr(module, "Response", _StubResponse)
    # Two segments around the CJK full stop (both chunk-A and chunk-B are asserted below).
    monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"text": "A。B"}))

    monkeypatch.setattr(module.TenantService, "get_info_by", lambda _tid: [])
    assert _run(handler("tenant-1"))["message"] == "Tenant not found!"

    monkeypatch.setattr(module.TenantService, "get_info_by", lambda _tid: [{"tenant_id": "tenant-1", "tts_id": ""}])
    assert _run(handler("tenant-1"))["message"] == "No default TTS model is set"

    class _TTSOk:
        def tts(self, txt):
            if not txt:
                return []
            yield f"chunk-{txt}".encode("utf-8")

    monkeypatch.setattr(module.TenantService, "get_info_by", lambda _tid: [{"tenant_id": "tenant-1", "tts_id": "tts-x"}])
    monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: _TTSOk())
    resp = _run(handler("tenant-1"))
    # Streaming audio response with anti-buffering headers.
    assert resp.mimetype == "audio/mpeg"
    assert resp.headers.get("Cache-Control") == "no-cache"
    assert resp.headers.get("Connection") == "keep-alive"
    assert resp.headers.get("X-Accel-Buffering") == "no"
    chunks = _run(_collect_stream(resp.body))
    assert any("chunk-A" in chunk for chunk in chunks)
    assert any("chunk-B" in chunk for chunk in chunks)

    class _TTSErr:
        def tts(self, _txt):
            raise RuntimeError("tts boom")

    monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: _TTSErr())
    resp = _run(handler("tenant-1"))
    chunks = _run(_collect_stream(resp.body))
    assert any('"code": 500' in chunk and "**ERROR**: tts boom" in chunk for chunk in chunks)
@pytest.mark.p2
def test_build_reference_chunks_metadata_matrix_unit(monkeypatch):
    """Metadata-attachment rules of _build_reference_chunks across flag/field combinations."""
    module = _load_session_module(monkeypatch)

    # include_metadata=False: chunks pass through untouched.
    monkeypatch.setattr(module, "chunks_format", lambda _reference: [{"dataset_id": "kb-1", "document_id": "doc-1"}])
    res = module._build_reference_chunks([], include_metadata=False)
    assert res == [{"dataset_id": "kb-1", "document_id": "doc-1"}]

    # Chunks missing either id never gain metadata.
    monkeypatch.setattr(module, "chunks_format", lambda _reference: [{"dataset_id": "kb-1"}, {"document_id": "doc-2"}])
    res = module._build_reference_chunks([], include_metadata=True)
    assert all("document_metadata" not in chunk for chunk in res)

    # metadata_fields with no usable names filters all metadata out.
    monkeypatch.setattr(module, "chunks_format", lambda _reference: [{"dataset_id": "kb-1", "document_id": "doc-1"}])
    monkeypatch.setattr(module.DocMetadataService, "get_metadata_for_documents", lambda _doc_ids, _kb_id: {"doc-1": {"author": "alice"}})
    res = module._build_reference_chunks([], include_metadata=True, metadata_fields=[1, None])
    assert "document_metadata" not in res[0]

    # Mixed datasets: lookups are grouped per kb and applied per document.
    source_chunks = [
        {"dataset_id": "kb-1", "document_id": "doc-1"},
        {"dataset_id": "kb-2", "document_id": "doc-2"},
        {"dataset_id": "kb-1", "document_id": "doc-3"},
        {"dataset_id": "kb-1", "document_id": None},
    ]
    monkeypatch.setattr(module, "chunks_format", lambda _reference: [dict(chunk) for chunk in source_chunks])

    def _get_metadata(_doc_ids, kb_id):
        if kb_id == "kb-1":
            return {"doc-1": {"author": "alice", "year": 2024}}
        if kb_id == "kb-2":
            return {"doc-2": {"author": "bob", "tag": "rag"}}
        return {}

    monkeypatch.setattr(module.DocMetadataService, "get_metadata_for_documents", _get_metadata)
    res = module._build_reference_chunks([], include_metadata=True, metadata_fields=["author", "missing", 3])
    assert res[0]["document_metadata"] == {"author": "alice"}
    assert res[1]["document_metadata"] == {"author": "bob"}
    assert "document_metadata" not in res[2]
    assert "document_metadata" not in res[3]

View File

@ -17,6 +17,7 @@
import pytest
from ragflow_sdk import RAGFlow
from ragflow_sdk.modules.agent import Agent
from ragflow_sdk.modules.session import Session
class _DummyResponse:
@ -27,6 +28,16 @@ class _DummyResponse:
return self._payload
@pytest.fixture(scope="session")
def auth():
    """Stub auth token so these unit tests do not depend on a live server login."""
    return "unit-auth"


@pytest.fixture(scope="session", autouse=True)
def set_tenant_info():
    """Autouse no-op; presumably neutralizes a heavier shared fixture -- confirm in conftest."""
    return None
@pytest.mark.p2
def test_list_agents_success_and_error(monkeypatch):
client = RAGFlow("token", "http://localhost:9380")
@ -122,3 +133,84 @@ def test_delete_agent_success_and_error(monkeypatch):
with pytest.raises(Exception) as exception_info:
client.delete_agent("agent-1")
assert "delete boom" in str(exception_info.value), str(exception_info.value)
@pytest.mark.p2
def test_agent_and_dsl_default_initialization():
    """Agent and Agent.Dsl fall back to their defaults when payload fields are absent."""
    client = RAGFlow("token", "http://localhost:9380")

    agent = Agent(client, {"id": "agent-1", "title": "Agent One"})
    assert agent.id == "agent-1"
    # Optional attributes default to None when the payload omits them.
    for attr in ("avatar", "canvas_type", "description", "dsl"):
        assert getattr(agent, attr) is None

    dsl = Agent.Dsl(client, {})
    assert dsl.answer == []
    # The default graph/components seed a "begin" node.
    assert "begin" in dsl.components
    assert dsl.components["begin"]["obj"]["component_name"] == "Begin"
    assert dsl.graph["nodes"][0]["id"] == "begin"
    # History-like collections all start empty.
    assert dsl.history == []
    assert dsl.messages == []
    assert dsl.path == []
    assert dsl.reference == []
@pytest.mark.p2
def test_agent_session_methods_success_and_error_paths(monkeypatch):
    """Session create/list/delete on Agent: request shapes, parsing, and error bubbling."""
    client = RAGFlow("token", "http://localhost:9380")
    agent = Agent(client, {"id": "agent-1"})
    recorded = {"post": [], "get": [], "rm": []}

    def _ok_post(path, json=None, stream=False, files=None):
        recorded["post"].append((path, json, stream, files))
        return _DummyResponse({"code": 0, "data": {"id": "session-1", "agent_id": "agent-1", "name": "one"}})

    def _ok_get(path, params=None):
        recorded["get"].append((path, params))
        return _DummyResponse(
            {
                "code": 0,
                "data": [
                    {"id": "session-1", "agent_id": "agent-1", "name": "one"},
                    {"id": "session-2", "agent_id": "agent-1", "name": "two"},
                ],
            }
        )

    def _ok_rm(path, payload):
        recorded["rm"].append((path, payload))
        return _DummyResponse({"code": 0, "message": "ok"})

    monkeypatch.setattr(agent, "post", _ok_post)
    monkeypatch.setattr(agent, "get", _ok_get)
    monkeypatch.setattr(agent, "rm", _ok_rm)

    session = agent.create_session(name="session-name")
    assert isinstance(session, Session), str(session)
    assert session.id == "session-1", str(session)
    assert recorded["post"][-1][0] == "/agents/agent-1/sessions"
    assert recorded["post"][-1][1] == {"name": "session-name"}

    sessions = agent.list_sessions(page=2, page_size=5, orderby="create_time", desc=False, id="session-1")
    assert len(sessions) == 2, str(sessions)
    assert all(isinstance(item, Session) for item in sessions), str(sessions)
    assert recorded["get"][-1][0] == "/agents/agent-1/sessions"
    assert recorded["get"][-1][1]["page"] == 2
    assert recorded["get"][-1][1]["id"] == "session-1"

    agent.delete_sessions(ids=["session-1", "session-2"])
    assert recorded["rm"][-1] == ("/agents/agent-1/sessions", {"ids": ["session-1", "session-2"]})

    # Non-zero response codes surface the server message as an exception.
    monkeypatch.setattr(agent, "post", lambda *_args, **_kwargs: _DummyResponse({"code": 1, "message": "create failed"}))
    with pytest.raises(Exception, match="create failed"):
        agent.create_session(name="bad")
    monkeypatch.setattr(agent, "get", lambda *_args, **_kwargs: _DummyResponse({"code": 2, "message": "list failed"}))
    with pytest.raises(Exception, match="list failed"):
        agent.list_sessions()
    monkeypatch.setattr(agent, "rm", lambda *_args, **_kwargs: _DummyResponse({"code": 3, "message": "delete failed"}))
    with pytest.raises(Exception, match="delete failed"):
        agent.delete_sessions(ids=["session-1"])

View File

@ -17,6 +17,28 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
import pytest
from configs import SESSION_WITH_CHAT_NAME_LIMIT
from ragflow_sdk import RAGFlow
from ragflow_sdk.modules.session import Session
class _DummyStreamResponse:
def __init__(self, lines):
self._lines = lines
def iter_lines(self, decode_unicode=True):
del decode_unicode
for line in self._lines:
yield line
@pytest.fixture(scope="session")
def auth():
    """Stub auth token for purely local unit runs."""
    return "unit-auth"


@pytest.fixture(scope="session", autouse=True)
def set_tenant_info():
    """Autouse no-op replacement for tenant bootstrap these tests do not require."""
    return None
@pytest.mark.usefixtures("clear_session_with_chat_assistants")
@ -74,3 +96,72 @@ class TestSessionWithChatAssistantCreate:
with pytest.raises(Exception) as exception_info:
chat_assistant.create_session(name="valid_name")
assert "You do not own the assistant" in str(exception_info.value)
@pytest.mark.p2
def test_session_module_streaming_and_helper_paths_unit(monkeypatch):
    """Streaming ask(): chat vs agent SSE parsing, malformed lines skipped, [DONE] terminates."""
    client = RAGFlow("token", "http://localhost:9380")
    chat_session = Session(client, {"id": "session-chat", "chat_id": "chat-1"})
    chat_done_session = Session(client, {"id": "session-chat-done", "chat_id": "chat-1"})
    agent_session = Session(client, {"id": "session-agent", "agent_id": "agent-1"})
    recorded = []

    # Mix of empty, malformed, non-message, bare-JSON, boolean, and [DONE] lines.
    chat_stream = _DummyStreamResponse(
        [
            "",
            "data: {bad json}",
            'data: {"event":"workflow_started","data":{"content":"skip"}}',
            '{"data":{"answer":"chat-answer","reference":{"chunks":[{"id":"chunk-1"}]}}}',
            'data: {"data": true}',
            "data: [DONE]",
        ]
    )
    agent_stream = _DummyStreamResponse(
        [
            "data: {bad json}",
            'data: {"event":"message","data":{"content":"agent-answer"}}',
            'data: {"event":"message_end","data":{"content":"done"}}',
        ]
    )

    def _chat_post(path, json=None, stream=False, files=None):
        recorded.append(("chat", path, json, stream, files))
        return chat_stream

    def _agent_post(path, json=None, stream=False, files=None):
        recorded.append(("agent", path, json, stream, files))
        return agent_stream

    monkeypatch.setattr(chat_session, "post", _chat_post)
    monkeypatch.setattr(
        chat_done_session,
        "post",
        lambda *_args, **_kwargs: _DummyStreamResponse(
            ['{"data":{"answer":"chat-done","reference":{"chunks":[]}}}', "data: [DONE]"]
        ),
    )
    monkeypatch.setattr(agent_session, "post", _agent_post)

    chat_messages = list(chat_session.ask("hello chat", stream=True, temperature=0.2))
    assert len(chat_messages) == 1
    assert chat_messages[0].content == "chat-answer"
    assert chat_messages[0].reference == [{"id": "chunk-1"}]

    chat_done_messages = list(chat_done_session.ask("hello done", stream=True))
    assert len(chat_done_messages) == 1
    assert chat_done_messages[0].content == "chat-done"

    agent_messages = list(agent_session.ask("hello agent", stream=True, top_p=0.8))
    assert len(agent_messages) == 1
    assert agent_messages[0].content == "agent-answer"

    # Only chat_session and agent_session record their calls, in that order.
    chat_call, agent_call = recorded
    assert chat_call[1] == "/chats/chat-1/completions"
    assert chat_call[2]["question"] == "hello chat"
    assert chat_call[2]["session_id"] == "session-chat"
    assert chat_call[2]["temperature"] == 0.2
    assert chat_call[3] is True
    assert agent_call[1] == "/agents/agent-1/completions"
    assert agent_call[2]["question"] == "hello agent"
    assert agent_call[2]["session_id"] == "session-agent"
    assert agent_call[2]["top_p"] == 0.8
    assert agent_call[3] is True

View File

@ -15,6 +15,18 @@
#
import pytest
from concurrent.futures import ThreadPoolExecutor, as_completed
from ragflow_sdk import RAGFlow
from ragflow_sdk.modules.session import Message, Session
@pytest.fixture(scope="session")
def auth():
    """Stub session-scoped auth token for offline unit runs."""
    return "unit-auth"


@pytest.fixture(scope="session", autouse=True)
def set_tenant_info():
    """Autouse no-op so no tenant bootstrap runs for these unit tests."""
    return None
class TestSessionsWithChatAssistantList:
@ -215,3 +227,54 @@ class TestSessionsWithChatAssistantList:
with pytest.raises(Exception) as exception_info:
chat_assistant.list_sessions()
assert "You don't own the assistant" in str(exception_info.value)
@pytest.mark.p2
def test_session_module_error_paths_unit(monkeypatch):
    """Non-streaming ask() failure modes plus Message construction defaults."""
    client = RAGFlow("token", "http://localhost:9380")

    # A forged, unknown session type is rejected up front.
    unknown_session = Session(client, {"id": "session-unknown", "chat_id": "chat-1"})
    unknown_session._Session__session_type = "unknown"  # noqa: SLF001
    with pytest.raises(Exception) as exception_info:
        list(unknown_session.ask("hello", stream=False))
    assert "Unknown session type" in str(exception_info.value)

    # A body that fails to parse as JSON is wrapped in an "Invalid response" error.
    bad_json_session = Session(client, {"id": "session-bad-json", "chat_id": "chat-1"})

    class _BadJsonResponse:
        def json(self):
            raise ValueError("json decode failed")

    monkeypatch.setattr(bad_json_session, "post", lambda *_args, **_kwargs: _BadJsonResponse())
    with pytest.raises(Exception) as exception_info:
        list(bad_json_session.ask("hello", stream=False))
    assert "Invalid response" in str(exception_info.value)

    # A well-formed payload yields exactly one parsed message.
    ok_json_session = Session(client, {"id": "session-ok-json", "chat_id": "chat-1"})

    class _OkJsonResponse:
        def json(self):
            return {"data": {"answer": "ok-answer", "reference": {"chunks": [{"id": "chunk-ok"}]}}}

    monkeypatch.setattr(ok_json_session, "post", lambda *_args, **_kwargs: _OkJsonResponse())
    ok_messages = list(ok_json_session.ask("hello", stream=False))
    assert len(ok_messages) == 1
    assert ok_messages[0].content == "ok-answer"
    assert ok_messages[0].reference == [{"id": "chunk-ok"}]

    # Transport-level exceptions propagate unchanged.
    transport_session = Session(client, {"id": "session-transport", "chat_id": "chat-1"})
    monkeypatch.setattr(
        transport_session,
        "post",
        lambda *_args, **_kwargs: (_ for _ in ()).throw(RuntimeError("transport boom")),
    )
    with pytest.raises(RuntimeError) as exception_info:
        list(transport_session.ask("hello", stream=False))
    assert "transport boom" in str(exception_info.value)

    # Message built from an empty payload exposes its documented defaults.
    message = Message(client, {})
    assert message.content == "Hi! I am your assistant, can I help you?"
    assert message.reference is None
    assert message.role == "assistant"
    assert message.prompt is None
    assert message.id is None

View File

@ -0,0 +1,197 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import importlib.util
import sys
import urllib.parse
from pathlib import Path
from types import ModuleType
import pytest
class _FakeResponse:
def __init__(self, payload=None, err=None):
self._payload = payload or {}
self._err = err
def raise_for_status(self):
if self._err:
raise self._err
def json(self):
return self._payload
def _base_config(scope="openid profile"):
return {
"client_id": "client-1",
"client_secret": "secret-1",
"authorization_url": "https://issuer.example/authorize",
"token_url": "https://issuer.example/token",
"userinfo_url": "https://issuer.example/userinfo",
"redirect_uri": "https://app.example/callback",
"scope": scope,
}
def _load_oauth_module(monkeypatch):
    """Import api/apps/auth/oauth.py in isolation, stubbing the common.http_client layer."""
    repo_root = Path(__file__).resolve().parents[4]

    # Fake the "common" package so oauth.py resolves common.http_client to our stubs.
    common_pkg = ModuleType("common")
    common_pkg.__path__ = [str(repo_root / "common")]
    monkeypatch.setitem(sys.modules, "common", common_pkg)

    http_client_mod = ModuleType("common.http_client")

    async def _default_async_request(*_args, **_kwargs):
        return _FakeResponse({})

    def _default_sync_request(*_args, **_kwargs):
        return _FakeResponse({})

    http_client_mod.async_request = _default_async_request
    http_client_mod.sync_request = _default_sync_request
    monkeypatch.setitem(sys.modules, "common.http_client", http_client_mod)

    # Register the api.apps.auth package chain without importing the real packages.
    for dotted, parts in (
        ("api", ("api",)),
        ("api.apps", ("api", "apps")),
        ("api.apps.auth", ("api", "apps", "auth")),
    ):
        pkg = ModuleType(dotted)
        pkg.__path__ = [str(repo_root.joinpath(*parts))]
        monkeypatch.setitem(sys.modules, dotted, pkg)

    # Force a fresh load of oauth.py from disk under its canonical dotted name.
    sys.modules.pop("api.apps.auth.oauth", None)
    oauth_path = repo_root / "api" / "apps" / "auth" / "oauth.py"
    oauth_spec = importlib.util.spec_from_file_location("api.apps.auth.oauth", oauth_path)
    oauth_module = importlib.util.module_from_spec(oauth_spec)
    monkeypatch.setitem(sys.modules, "api.apps.auth.oauth", oauth_module)
    oauth_spec.loader.exec_module(oauth_module)
    return oauth_module
@pytest.fixture(scope="session", autouse=True)
def set_tenant_info():
    """Autouse no-op; presumably shadows a heavier shared fixture -- confirm in conftest."""
    return None
@pytest.mark.p2
def test_oauth_client_sync_matrix_unit(monkeypatch):
    """OAuthClient config parsing, auth-URL building, and sync token/userinfo flows."""
    oauth_module = _load_oauth_module(monkeypatch)
    client = oauth_module.OAuthClient(_base_config())

    # Configuration fields are copied onto the client.
    assert client.client_id == "client-1"
    assert client.client_secret == "secret-1"
    assert client.authorization_url.endswith("/authorize")
    assert client.token_url.endswith("/token")
    assert client.userinfo_url.endswith("/userinfo")
    assert client.redirect_uri.endswith("/callback")
    assert client.scope == "openid profile"
    assert client.http_request_timeout == 7

    info = oauth_module.UserInfo("u@example.com", "user1", "User One", "avatar-url")
    assert info.to_dict() == {
        "email": "u@example.com",
        "username": "user1",
        "nickname": "User One",
        "avatar_url": "avatar-url",
    }

    # State with spaces/slashes survives the URL-encode/parse round trip.
    parsed = urllib.parse.urlparse(client.get_authorization_url(state="s p/a?ce"))
    query = urllib.parse.parse_qs(parsed.query)
    assert parsed.scheme == "https"
    assert query["client_id"] == ["client-1"]
    assert query["redirect_uri"] == ["https://app.example/callback"]
    assert query["response_type"] == ["code"]
    assert query["scope"] == ["openid profile"]
    assert query["state"] == ["s p/a?ce"]

    # A None scope omits the parameter entirely.
    no_scope_client = oauth_module.OAuthClient(_base_config(scope=None))
    no_scope_query = urllib.parse.parse_qs(urllib.parse.urlparse(no_scope_client.get_authorization_url()).query)
    assert "scope" not in no_scope_query

    call_log = []

    def _sync_ok(method, url, data=None, headers=None, timeout=None):
        call_log.append((method, url, data, headers, timeout))
        if url.endswith("/token"):
            return _FakeResponse({"access_token": "token-1"})
        return _FakeResponse({"email": "user@example.com", "picture": "id-picture"})

    monkeypatch.setattr(oauth_module, "sync_request", _sync_ok)
    token = client.exchange_code_for_token("code-1")
    assert token["access_token"] == "token-1"

    user_info = client.fetch_user_info("access-1")
    assert isinstance(user_info, oauth_module.UserInfo)
    assert user_info.to_dict() == {
        "email": "user@example.com",
        "username": "user",
        "nickname": "user",
        "avatar_url": "id-picture",
    }
    # Token exchange is a POST accepting JSON; userinfo is a bearer-authorized GET.
    assert call_log[0][0] == "POST"
    assert call_log[0][3]["Accept"] == "application/json"
    assert call_log[1][0] == "GET"
    assert call_log[1][3]["Authorization"] == "Bearer access-1"

    normalized = client.normalize_user_info(
        {"email": "fallback@example.com", "username": "fallback-user", "nickname": "fallback-nick", "avatar_url": "direct-avatar"}
    )
    assert normalized.to_dict()["avatar_url"] == "direct-avatar"

    # raise_for_status failures are wrapped in ValueError with a stable prefix.
    monkeypatch.setattr(oauth_module, "sync_request", lambda *_args, **_kwargs: _FakeResponse(err=RuntimeError("status boom")))
    with pytest.raises(ValueError, match="Failed to exchange authorization code for token: status boom"):
        client.exchange_code_for_token("code-2")
    with pytest.raises(ValueError, match="Failed to fetch user info: status boom"):
        client.fetch_user_info("access-2")
@pytest.mark.p2
def test_oauth_client_async_matrix_unit(monkeypatch):
    """Async token exchange and userinfo fetch, including wrapped failure paths."""
    oauth_module = _load_oauth_module(monkeypatch)
    client = oauth_module.OAuthClient(_base_config())

    async def _async_ok(method, url, data=None, headers=None, **kwargs):
        _ = (method, data, headers, kwargs.get("timeout"))
        if url.endswith("/token"):
            return _FakeResponse({"access_token": "token-async"})
        return _FakeResponse({"email": "async@example.com", "username": "async-user", "nickname": "Async User", "avatar_url": "async-avatar"})

    monkeypatch.setattr(oauth_module, "async_request", _async_ok)
    token = asyncio.run(client.async_exchange_code_for_token("code-a"))
    assert token["access_token"] == "token-async"
    info = asyncio.run(client.async_fetch_user_info("async-token"))
    assert info.to_dict() == {
        "email": "async@example.com",
        "username": "async-user",
        "nickname": "Async User",
        "avatar_url": "async-avatar",
    }

    # Failures raised by raise_for_status are wrapped in ValueError, same as the sync path.
    async def _async_fail(*_args, **_kwargs):
        return _FakeResponse(err=RuntimeError("async boom"))

    monkeypatch.setattr(oauth_module, "async_request", _async_fail)
    with pytest.raises(ValueError, match="Failed to exchange authorization code for token: async boom"):
        asyncio.run(client.async_exchange_code_for_token("code-b"))
    with pytest.raises(ValueError, match="Failed to fetch user info: async boom"):
        asyncio.run(client.async_fetch_user_info("async-token-2"))

File diff suppressed because it is too large Load Diff

View File

@ -17,6 +17,7 @@
import asyncio
import base64
import importlib.util
import json
import sys
from pathlib import Path
from types import ModuleType, SimpleNamespace
@ -71,6 +72,7 @@ class _DummyRetCode:
SUCCESS = 0
DATA_ERROR = 102
EXCEPTION_ERROR = 100
OPERATING_ERROR = 103
class _DummyParserType:
@ -204,6 +206,7 @@ def _load_chunk_module(monkeypatch):
class _DummyLLMType:
EMBEDDING = SimpleNamespace(value="embedding")
CHAT = SimpleNamespace(value="chat")
RERANK = SimpleNamespace(value="rerank")
constants_mod.RetCode = _DummyRetCode
constants_mod.LLMType = _DummyLLMType
@ -365,6 +368,11 @@ def _set_request_json(monkeypatch, module, payload):
monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue(payload))
@pytest.fixture(scope="session", autouse=True)
def set_tenant_info():
return None
@pytest.mark.p2
def test_list_chunk_exception_branches_unit(monkeypatch):
module = _load_chunk_module(monkeypatch)
@ -663,7 +671,180 @@ def test_create_chunk_guards_pagerank_and_success_unit(monkeypatch):
)
res = _run(module.create())
assert res["code"] == 0, res
assert res["data"]["chunk_id"], res
assert module.settings.docStoreConn.inserted, "insert should be called"
inserted = module.settings.docStoreConn.inserted[-1]
assert "pagerank_flt" in inserted
assert module.DocumentService.increment_calls, "increment_chunk_num should be called"
async def _raise_thread_pool(_func):
raise RuntimeError("create tp boom")
monkeypatch.setattr(module, "thread_pool_exec", _raise_thread_pool)
_set_request_json(monkeypatch, module, {"doc_id": "doc-1", "content_with_weight": "chunk"})
res = _run(module.create())
assert res["code"] == module.RetCode.EXCEPTION_ERROR, res
assert "create tp boom" in res["message"], res
@pytest.mark.p2
def test_retrieval_test_branch_matrix_unit(monkeypatch):
    """Walk ``retrieval_test`` through its permission, validation, success and error branches."""
    module = _load_chunk_module(monkeypatch)
    module.request = SimpleNamespace(headers={"X-Request-ID": "req-r"}, args={})
    applied_filters = []
    llm_calls = []
    cross_calls = []
    keyword_calls = []

    async def _apply_filter(meta_data_filter, metas, question, chat_mdl, local_doc_ids):
        # Record every argument so the auto-metadata branch can be asserted.
        applied_filters.append(
            {
                "meta_data_filter": meta_data_filter,
                "metas": metas,
                "question": question,
                "chat_mdl": chat_mdl,
                "local_doc_ids": list(local_doc_ids),
            }
        )
        return ["doc-filtered"]

    async def _cross_languages(_tenant_id, _dialog, question, langs):
        cross_calls.append((question, tuple(langs)))
        return f"{question}-xl"

    async def _keyword_extraction(_chat_mdl, question):
        keyword_calls.append(question)
        return "-kw"

    class _Retriever:
        # Configurable retriever stub: "ok" succeeds, "not_found" raises an
        # error containing "not_found", "explode" raises a generic error.
        def __init__(self, mode="ok"):
            self.mode = mode
            self.retrieval_questions = []

        async def retrieval(self, question, *_args, **_kwargs):
            if self.mode == "not_found":
                raise Exception("boom not_found boom")
            if self.mode == "explode":
                raise RuntimeError("retrieval boom")
            self.retrieval_questions.append(question)
            return {"chunks": [{"id": "c1", "vector": [0.1], "content_with_weight": "chunk-content"}]}

        def retrieval_by_children(self, chunks, _tenant_ids):
            return list(chunks)

    class _KgRetriever:
        async def retrieval(self, *_args, **_kwargs):
            return {"id": "kg-1", "content_with_weight": "kg-content"}

    class _NoContentKgRetriever:
        async def retrieval(self, *_args, **_kwargs):
            return {"id": "kg-2", "content_with_weight": ""}

    monkeypatch.setattr(module, "LLMBundle", lambda *args, **kwargs: llm_calls.append((args, kwargs)) or SimpleNamespace())
    monkeypatch.setattr(module.DocMetadataService, "get_flatted_meta_by_kbs", lambda _kb_ids: [{"meta": "v"}], raising=False)
    monkeypatch.setattr(module, "apply_meta_data_filter", _apply_filter)
    monkeypatch.setattr(module.SearchService, "get_detail", lambda _sid: {"search_config": {"meta_data_filter": {"method": "auto"}, "chat_id": "chat-1"}}, raising=False)
    monkeypatch.setattr(module, "cross_languages", _cross_languages)
    monkeypatch.setattr(module, "keyword_extraction", _keyword_extraction)
    monkeypatch.setattr(module, "label_question", lambda *_args, **_kwargs: ["lbl"])
    monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [_DummyTenant("tenant-1")])
    # Branch: requester is not the dataset owner -> OPERATING_ERROR.
    monkeypatch.setattr(module.KnowledgebaseService, "query", lambda **_kwargs: False, raising=False)
    _set_request_json(monkeypatch, module, {"kb_id": "kb-1", "question": "q", "search_id": "search-1"})
    res = _run(module.retrieval_test())
    assert res["code"] == module.RetCode.OPERATING_ERROR, res
    assert "Only owner of dataset authorized for this operation." in res["message"], res
    assert applied_filters and applied_filters[-1]["meta_data_filter"]["method"] == "auto"
    assert llm_calls, "search_id metadata auto branch should instantiate chat model"
    # Branch: empty kb_id list -> DATA_ERROR.
    _set_request_json(monkeypatch, module, {"kb_id": [], "question": "q"})
    res = _run(module.retrieval_test())
    assert res["code"] == module.RetCode.DATA_ERROR, res
    assert "Please specify dataset firstly." in res["message"], res
    # Branch: dataset lookup fails -> DATA_ERROR.
    monkeypatch.setattr(module.KnowledgebaseService, "query", lambda **_kwargs: True, raising=False)
    monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (False, None), raising=False)
    _set_request_json(
        monkeypatch,
        module,
        {"kb_id": ["kb-1"], "question": "q", "meta_data_filter": {"method": "semi_auto"}},
    )
    res = _run(module.retrieval_test())
    assert res["code"] == module.RetCode.DATA_ERROR, res
    assert "Knowledgebase not found!" in res["message"], res
    # Branch: full success with cross-language rewrite, keyword extraction,
    # reranking and a knowledge-graph chunk prepended to the results.
    retriever = _Retriever(mode="ok")
    monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, SimpleNamespace(tenant_id="tenant-kb", embd_id="embd-1")), raising=False)
    monkeypatch.setattr(module.settings, "retriever", retriever)
    monkeypatch.setattr(module.settings, "kg_retriever", _KgRetriever(), raising=False)
    _set_request_json(
        monkeypatch,
        module,
        {
            "kb_id": ["kb-1"],
            "question": "q",
            "cross_languages": ["fr"],
            "rerank_id": "rerank-1",
            "keyword": True,
            "use_kg": True,
        },
    )
    res = _run(module.retrieval_test())
    assert res["code"] == 0, res
    assert cross_calls[-1] == ("q", ("fr",))
    assert keyword_calls[-1] == "q-xl"
    # Question is transformed as q -> q-xl (cross language) -> q-xl-kw (keyword).
    assert retriever.retrieval_questions[-1] == "q-xl-kw"
    assert res["data"]["chunks"][0]["id"] == "kg-1", res
    # Vectors must be stripped from the response payload.
    assert all("vector" not in chunk for chunk in res["data"]["chunks"])
    # Branch: a KG chunk with empty content is not included in the results.
    monkeypatch.setattr(module.settings, "kg_retriever", _NoContentKgRetriever(), raising=False)
    _set_request_json(monkeypatch, module, {"kb_id": ["kb-1"], "question": "q", "use_kg": True})
    res = _run(module.retrieval_test())
    assert res["code"] == 0, res
    assert res["data"]["chunks"][0]["id"] == "c1", res
    # Branch: retriever error containing "not_found" maps to DATA_ERROR.
    monkeypatch.setattr(module.settings, "retriever", _Retriever(mode="not_found"))
    _set_request_json(monkeypatch, module, {"kb_id": ["kb-1"], "question": "q"})
    res = _run(module.retrieval_test())
    assert res["code"] == module.RetCode.DATA_ERROR, res
    assert "No chunk found! Check the chunk status please!" in res["message"], res
    # Branch: any other retriever exception surfaces as EXCEPTION_ERROR.
    monkeypatch.setattr(module.settings, "retriever", _Retriever(mode="explode"))
    _set_request_json(monkeypatch, module, {"kb_id": ["kb-1"], "question": "q"})
    res = _run(module.retrieval_test())
    assert res["code"] == module.RetCode.EXCEPTION_ERROR, res
    assert "retrieval boom" in res["message"], res
@pytest.mark.p2
def test_knowledge_graph_repeat_deal_matrix_unit(monkeypatch):
    """``knowledge_graph`` tolerates bad graph JSON and renames duplicate mind-map ids."""
    module = _load_chunk_module(monkeypatch)
    module.request = SimpleNamespace(args={"doc_id": "doc-1"}, headers={})
    # Mind map whose nodes reuse the id "dup" at several depths.
    payload = {
        "id": "root",
        "children": [
            {"id": "dup"},
            {"id": "dup", "children": [{"id": "dup"}]},
        ],
    }

    class _SRes:
        # Search result stub: one unparseable graph doc plus one valid mind map.
        ids = ["bad-json", "mind-map"]
        field = {
            "bad-json": {"knowledge_graph_kwd": "graph", "content_with_weight": "{bad json"},
            "mind-map": {"knowledge_graph_kwd": "mind_map", "content_with_weight": json.dumps(payload)},
        }

    async def _search(*_args, **_kwargs):
        return _SRes()

    monkeypatch.setattr(module.settings.retriever, "search", _search)
    res = _run(module.knowledge_graph())
    assert res["code"] == 0, res
    # Broken graph JSON degrades to an empty graph instead of failing.
    assert res["data"]["graph"] == {}, res
    mind_map = res["data"]["mind_map"]
    # Repeated ids get "(n)" suffixes: dup, dup(1), dup(2).
    assert mind_map["children"][0]["id"] == "dup", res
    assert mind_map["children"][1]["id"] == "dup(1)", res
    assert mind_map["children"][1]["children"][0]["id"] == "dup(2)", res

View File

@ -0,0 +1,711 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import importlib.util
import json
import sys
from pathlib import Path
from types import ModuleType, SimpleNamespace
import pytest
class _DummyManager:
def route(self, *_args, **_kwargs):
def decorator(func):
return func
return decorator
class _AwaitableValue:
def __init__(self, value):
self._value = value
def __await__(self):
async def _co():
return self._value
return _co().__await__()
class _Args(dict):
def get(self, key, default=None, type=None):
value = super().get(key, default)
if type is None:
return value
try:
return type(value)
except (TypeError, ValueError):
return default
def to_dict(self, flat=True):
return dict(self)
class _FakeResponse:
def __init__(self, body, status_code):
self.body = body
self.status_code = status_code
self.headers = {}
class _FakeConnectorRecord:
def __init__(self, payload):
self._payload = payload
def to_dict(self):
return dict(self._payload)
class _FakeCredentials:
def __init__(self, raw='{"refresh_token":"rt","access_token":"at"}'):
self._raw = raw
def to_json(self):
return self._raw
class _FakeFlow:
    """google-auth ``Flow`` stub that records authorization/token interactions."""

    def __init__(self, client_config, scopes):
        self.client_config = client_config
        self.scopes = scopes
        self.redirect_uri = None
        self.credentials = _FakeCredentials()
        self.auth_kwargs = None
        self.token_code = None

    def authorization_url(self, **kwargs):
        # Remember the exact kwargs so tests can assert on them later.
        self.auth_kwargs = dict(kwargs)
        state = kwargs["state"]
        return f"https://oauth.example/{state}", state

    def fetch_token(self, code):
        # Only record the exchanged code; no real token round-trip happens.
        self.token_code = code
class _FakeBoxToken:
def __init__(self, access_token, refresh_token):
self.access_token = access_token
self.refresh_token = refresh_token
class _FakeBoxOAuth:
def __init__(self, config):
self.config = config
self.exchange_code = None
def get_authorize_url(self, options):
return f"https://box.example/auth?state={options.state}&redirect={options.redirect_uri}"
def get_tokens_authorization_code_grant(self, code):
self.exchange_code = code
def retrieve_token(self):
return _FakeBoxToken("box-access", "box-refresh")
class _FakeRedis:
def __init__(self):
self.store = {}
self.set_calls = []
self.deleted = []
def get(self, key):
return self.store.get(key)
def set_obj(self, key, obj, ttl):
self.set_calls.append((key, obj, ttl))
self.store[key] = json.dumps(obj)
def delete(self, key):
self.deleted.append(key)
self.store.pop(key, None)
def _run(coro):
    # Drive an async route handler to completion on a fresh event loop.
    return asyncio.run(coro)
def _set_request(module, *, args=None, json_body=None):
    """Install a stub ``request`` on *module* with query args and an awaitable JSON body."""
    module.request = SimpleNamespace(
        args=_Args(args or {}),
        json=_AwaitableValue({} if json_body is None else json_body),
    )
@pytest.fixture(scope="session")
def auth():
return "unit-auth"
@pytest.fixture(scope="session", autouse=True)
def set_tenant_info():
return None
def _load_connector_app(monkeypatch):
    """Load ``api/apps/connector_app.py`` in isolation and return the module.

    Every project-internal dependency (db services, api utils, constants,
    redis, quart, google/box SDKs) is replaced with a stub module registered
    in ``sys.modules`` *before* the target file is executed, so its route
    handlers can be driven as plain functions without a running app.
    ``monkeypatch.setitem`` ensures all stubs are removed after each test.
    """
    repo_root = Path(__file__).resolve().parents[4]
    # --- "api" package skeleton with a fake logged-in user ----------------
    api_pkg = ModuleType("api")
    api_pkg.__path__ = [str(repo_root / "api")]
    monkeypatch.setitem(sys.modules, "api", api_pkg)
    apps_mod = ModuleType("api.apps")
    apps_mod.__path__ = [str(repo_root / "api" / "apps")]
    apps_mod.current_user = SimpleNamespace(id="tenant-1")
    apps_mod.login_required = lambda fn: fn
    monkeypatch.setitem(sys.modules, "api.apps", apps_mod)
    db_mod = ModuleType("api.db")
    db_mod.InputType = SimpleNamespace(POLL="POLL")
    monkeypatch.setitem(sys.modules, "api.db", db_mod)
    services_pkg = ModuleType("api.db.services")
    services_pkg.__path__ = []
    monkeypatch.setitem(sys.modules, "api.db.services", services_pkg)
    # --- connector / sync-log service stubs (tests re-patch per scenario) --
    connector_service_mod = ModuleType("api.db.services.connector_service")

    class _StubConnectorService:
        @staticmethod
        def update_by_id(*_args, **_kwargs):
            return True

        @staticmethod
        def save(**_kwargs):
            return True

        @staticmethod
        def get_by_id(_connector_id):
            return True, _FakeConnectorRecord({"id": _connector_id})

        @staticmethod
        def list(_tenant_id):
            return []

        @staticmethod
        def resume(*_args, **_kwargs):
            return True

        @staticmethod
        def rebuild(*_args, **_kwargs):
            return None

        @staticmethod
        def delete_by_id(*_args, **_kwargs):
            return True

    class _StubSyncLogsService:
        @staticmethod
        def list_sync_tasks(*_args, **_kwargs):
            return [], 0

    connector_service_mod.ConnectorService = _StubConnectorService
    connector_service_mod.SyncLogsService = _StubSyncLogsService
    monkeypatch.setitem(sys.modules, "api.db.services.connector_service", connector_service_mod)
    # --- api.utils.api_utils stub: JSON result helpers and no-op validator -
    api_utils_mod = ModuleType("api.utils.api_utils")

    async def _get_request_json():
        return {}

    api_utils_mod.get_request_json = _get_request_json
    api_utils_mod.get_json_result = lambda data=None, message="", code=0: {
        "code": code,
        "message": message,
        "data": data,
    }
    api_utils_mod.get_data_error_result = lambda message="", code=400, data=None: {
        "code": code,
        "message": message,
        "data": data,
    }
    api_utils_mod.validate_request = lambda *_args, **_kwargs: (lambda fn: fn)
    monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod)
    # --- common.* constants and data-source configuration stubs ------------
    constants_mod = ModuleType("common.constants")
    constants_mod.RetCode = SimpleNamespace(
        ARGUMENT_ERROR=101,
        SERVER_ERROR=500,
        RUNNING=102,
        PERMISSION_ERROR=403,
    )
    constants_mod.TaskStatus = SimpleNamespace(SCHEDULE="schedule", CANCEL="cancel")
    monkeypatch.setitem(sys.modules, "common.constants", constants_mod)
    config_mod = ModuleType("common.data_source.config")
    config_mod.GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI = "https://example.com/drive"
    config_mod.GMAIL_WEB_OAUTH_REDIRECT_URI = "https://example.com/gmail"
    config_mod.BOX_WEB_OAUTH_REDIRECT_URI = "https://example.com/box"
    config_mod.DocumentSource = SimpleNamespace(GMAIL="gmail", GOOGLE_DRIVE="google-drive")
    monkeypatch.setitem(sys.modules, "common.data_source.config", config_mod)
    google_constants_mod = ModuleType("common.data_source.google_util.constant")
    google_constants_mod.WEB_OAUTH_POPUP_TEMPLATE = (
        "<html><head><title>{title}</title></head>"
        "<body><h1>{heading}</h1><p>{message}</p><script>{payload_json}</script><script>{auto_close}</script></body></html>"
    )
    google_constants_mod.GOOGLE_SCOPES = {
        config_mod.DocumentSource.GMAIL: ["scope-gmail"],
        config_mod.DocumentSource.GOOGLE_DRIVE: ["scope-drive"],
    }
    monkeypatch.setitem(sys.modules, "common.data_source.google_util.constant", google_constants_mod)
    misc_mod = ModuleType("common.misc_utils")
    misc_mod.get_uuid = lambda: "uuid-from-helper"
    monkeypatch.setitem(sys.modules, "common.misc_utils", misc_mod)
    # --- rag redis stub ----------------------------------------------------
    rag_pkg = ModuleType("rag")
    rag_pkg.__path__ = [str(repo_root / "rag")]
    monkeypatch.setitem(sys.modules, "rag", rag_pkg)
    rag_utils_pkg = ModuleType("rag.utils")
    rag_utils_pkg.__path__ = [str(repo_root / "rag" / "utils")]
    monkeypatch.setitem(sys.modules, "rag.utils", rag_utils_pkg)
    redis_mod = ModuleType("rag.utils.redis_conn")
    redis_mod.REDIS_CONN = _FakeRedis()
    monkeypatch.setitem(sys.modules, "rag.utils.redis_conn", redis_mod)
    # --- quart + third-party SDK stubs -------------------------------------
    quart_mod = ModuleType("quart")
    quart_mod.request = SimpleNamespace(args=_Args(), json=_AwaitableValue({}))

    async def _make_response(body, status_code):
        return _FakeResponse(body, status_code)

    quart_mod.make_response = _make_response
    monkeypatch.setitem(sys.modules, "quart", quart_mod)
    google_pkg = ModuleType("google_auth_oauthlib")
    google_pkg.__path__ = []
    monkeypatch.setitem(sys.modules, "google_auth_oauthlib", google_pkg)
    google_flow_mod = ModuleType("google_auth_oauthlib.flow")

    class _StubFlow:
        @classmethod
        def from_client_config(cls, client_config, scopes):
            return _FakeFlow(client_config, scopes)

    google_flow_mod.Flow = _StubFlow
    monkeypatch.setitem(sys.modules, "google_auth_oauthlib.flow", google_flow_mod)
    box_mod = ModuleType("box_sdk_gen")

    class _OAuthConfig:
        def __init__(self, client_id, client_secret):
            self.client_id = client_id
            self.client_secret = client_secret

    class _GetAuthorizeUrlOptions:
        def __init__(self, redirect_uri, state):
            self.redirect_uri = redirect_uri
            self.state = state

    box_mod.BoxOAuth = _FakeBoxOAuth
    box_mod.OAuthConfig = _OAuthConfig
    box_mod.GetAuthorizeUrlOptions = _GetAuthorizeUrlOptions
    monkeypatch.setitem(sys.modules, "box_sdk_gen", box_mod)
    # --- execute the target module against the stubbed environment ---------
    module_path = repo_root / "api" / "apps" / "connector_app.py"
    spec = importlib.util.spec_from_file_location("test_connector_routes_unit", module_path)
    module = importlib.util.module_from_spec(spec)
    # Inject the blueprint manager before exec so route decorators are no-ops.
    module.manager = _DummyManager()
    spec.loader.exec_module(module)
    return module
@pytest.mark.p2
def test_connector_basic_routes_and_task_controls(monkeypatch):
    """Cover the connector CRUD routes plus logs, resume, rebuild and removal."""
    module = _load_connector_app(monkeypatch)

    async def _no_sleep(_secs):
        # Neutralize any sleeps the handlers might await.
        return None

    monkeypatch.setattr(module.asyncio, "sleep", _no_sleep)
    records = {"conn-1": _FakeConnectorRecord({"id": "conn-1", "source": "drive"})}
    update_calls = []
    save_calls = []
    resume_calls = []
    delete_calls = []
    monkeypatch.setattr(module.ConnectorService, "update_by_id", lambda cid, payload: update_calls.append((cid, payload)))

    def _save(**payload):
        save_calls.append(payload)
        records[payload["id"]] = _FakeConnectorRecord(payload)

    monkeypatch.setattr(module.ConnectorService, "save", _save)
    monkeypatch.setattr(module.ConnectorService, "get_by_id", lambda cid: (True, records[cid]))
    monkeypatch.setattr(module.ConnectorService, "list", lambda tenant_id: [{"id": "listed", "tenant": tenant_id}])
    monkeypatch.setattr(module.SyncLogsService, "list_sync_tasks", lambda cid, page, page_size: ([{"id": "log-1"}], 9))
    monkeypatch.setattr(module.ConnectorService, "resume", lambda cid, status: resume_calls.append((cid, status)))
    monkeypatch.setattr(module.ConnectorService, "delete_by_id", lambda cid: delete_calls.append(cid))
    monkeypatch.setattr(module, "get_uuid", lambda: "generated-id")
    # Update path: payload carrying an id triggers update_by_id.
    monkeypatch.setattr(
        module,
        "get_request_json",
        lambda: _AwaitableValue({"id": "conn-1", "refresh_freq": 7, "config": {"x": 1}}),
    )
    res = _run(module.set_connector())
    assert update_calls == [("conn-1", {"refresh_freq": 7, "config": {"x": 1}})]
    assert res["data"]["id"] == "conn-1"
    # Create path: no id, so a fresh uuid is assigned and save() is used.
    monkeypatch.setattr(
        module,
        "get_request_json",
        lambda: _AwaitableValue({"name": "new", "source": "gmail", "config": {"y": 2}}),
    )
    res = _run(module.set_connector())
    assert save_calls[-1]["id"] == "generated-id"
    assert save_calls[-1]["tenant_id"] == "tenant-1"
    assert save_calls[-1]["input_type"] == module.InputType.POLL
    assert res["data"]["id"] == "generated-id"
    list_res = module.list_connector()
    assert list_res["data"] == [{"id": "listed", "tenant": "tenant-1"}]
    # get_connector: missing vs. found record.
    monkeypatch.setattr(module.ConnectorService, "get_by_id", lambda _cid: (False, None))
    missing_res = module.get_connector("missing")
    assert missing_res["message"] == "Can't find this Connector!"
    monkeypatch.setattr(module.ConnectorService, "get_by_id", lambda cid: (True, _FakeConnectorRecord({"id": cid})))
    found_res = module.get_connector("conn-2")
    assert found_res["data"]["id"] == "conn-2"
    # list_logs forwards pagination args and returns total + logs.
    _set_request(module, args={"page": "2", "page_size": "7"})
    logs_res = module.list_logs("conn-log")
    assert logs_res["data"] == {"total": 9, "logs": [{"id": "log-1"}]}
    # resume=True schedules the task, resume=False cancels it.
    monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"resume": True}))
    assert _run(module.resume("conn-r1"))["data"] is True
    monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"resume": False}))
    assert _run(module.resume("conn-r2"))["data"] is True
    assert ("conn-r1", module.TaskStatus.SCHEDULE) in resume_calls
    assert ("conn-r2", module.TaskStatus.CANCEL) in resume_calls
    # rebuild: a non-None service result means failure (SERVER_ERROR).
    monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"kb_id": "kb-1"}))
    monkeypatch.setattr(module.ConnectorService, "rebuild", lambda *_args: "rebuild-failed")
    failed_rebuild = _run(module.rebuild("conn-rb"))
    assert failed_rebuild["code"] == module.RetCode.SERVER_ERROR
    assert failed_rebuild["data"] is False
    monkeypatch.setattr(module.ConnectorService, "rebuild", lambda *_args: None)
    ok_rebuild = _run(module.rebuild("conn-rb"))
    assert ok_rebuild["data"] is True
    # Removal cancels the task and deletes the record.
    rm_res = module.rm_connector("conn-rm")
    assert rm_res["data"] is True
    assert ("conn-rm", module.TaskStatus.CANCEL) in resume_calls
    assert delete_calls == ["conn-rm"]
@pytest.mark.p2
def test_connector_oauth_helper_functions(monkeypatch):
    """Unit-test the pure OAuth helpers: cache keys, credential parsing, popup HTML."""
    module = _load_connector_app(monkeypatch)
    assert module._web_state_cache_key("flow-a", "gmail") == "gmail_web_flow_state:flow-a"
    assert module._web_result_cache_key("flow-b", "google-drive") == "google-drive_web_flow_result:flow-b"
    creds_dict = {"web": {"client_id": "id"}}
    # _load_credentials accepts both a dict and its JSON string form.
    assert module._load_credentials(creds_dict) == creds_dict
    assert module._load_credentials(json.dumps(creds_dict)) == creds_dict
    with pytest.raises(ValueError, match="Invalid Google credentials JSON"):
        module._load_credentials("{not-json")
    # Client config must carry a 'web' section.
    assert module._get_web_client_config(creds_dict) == {"web": {"client_id": "id"}}
    with pytest.raises(ValueError, match="must include a 'web'"):
        module._get_web_client_config({"installed": {"client_id": "id"}})
    # Success popup embeds the source-specific channel name.
    popup_ok = _run(module._render_web_oauth_popup("flow-1", True, "done", "gmail"))
    assert popup_ok.status_code == 200
    assert popup_ok.headers["Content-Type"] == "text/html; charset=utf-8"
    assert "Authorization complete" in popup_ok.body
    assert "ragflow-gmail-oauth" in popup_ok.body
    # Failure popup HTML-escapes the message.
    popup_error = _run(module._render_web_oauth_popup("flow-2", False, "<denied>", "google-drive"))
    assert popup_error.status_code == 200
    assert "Authorization failed" in popup_error.body
    assert "&lt;denied&gt;" in popup_error.body
@pytest.mark.p2
def test_start_google_web_oauth_matrix(monkeypatch):
    """Walk ``start_google_web_oauth`` through every validation branch plus success."""
    module = _load_connector_app(monkeypatch)
    redis = _FakeRedis()
    monkeypatch.setattr(module, "REDIS_CONN", redis)
    monkeypatch.setattr(module.time, "time", lambda: 1700000000)
    flow_calls = []

    def _from_client_config(client_config, scopes):
        # Capture every constructed flow so scopes can be asserted later.
        flow = _FakeFlow(client_config, scopes)
        flow_calls.append(flow)
        return flow

    monkeypatch.setattr(module.Flow, "from_client_config", staticmethod(_from_client_config))
    # Unknown connector type.
    _set_request(module, args={"type": "invalid"})
    monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"credentials": "{}"}))
    invalid_type = _run(module.start_google_web_oauth())
    assert invalid_type["code"] == module.RetCode.ARGUMENT_ERROR
    # Redirect URI not configured for gmail.
    monkeypatch.setattr(module, "GMAIL_WEB_OAUTH_REDIRECT_URI", "")
    _set_request(module, args={"type": "gmail"})
    missing_redirect = _run(module.start_google_web_oauth())
    assert missing_redirect["code"] == module.RetCode.SERVER_ERROR
    monkeypatch.setattr(module, "GMAIL_WEB_OAUTH_REDIRECT_URI", "https://example.com/gmail")
    monkeypatch.setattr(module, "GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI", "https://example.com/drive")
    # Malformed credentials JSON.
    _set_request(module, args={"type": "google-drive"})
    monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"credentials": "{invalid-json"}))
    invalid_credentials = _run(module.start_google_web_oauth())
    assert invalid_credentials["code"] == module.RetCode.ARGUMENT_ERROR
    # Credentials already carrying a refresh token are rejected.
    monkeypatch.setattr(
        module,
        "get_request_json",
        lambda: _AwaitableValue({"credentials": json.dumps({"web": {"client_id": "id"}, "refresh_token": "rt"})}),
    )
    has_refresh_token = _run(module.start_google_web_oauth())
    assert has_refresh_token["code"] == module.RetCode.ARGUMENT_ERROR
    # Credentials missing the required 'web' section.
    monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"credentials": json.dumps({"installed": {"x": 1}})}))
    missing_web = _run(module.start_google_web_oauth())
    assert missing_web["code"] == module.RetCode.ARGUMENT_ERROR
    # Success paths: gmail explicitly, then no type argument (drive flow succeeds).
    ids = iter(["flow-gmail", "flow-drive"])
    monkeypatch.setattr(module.uuid, "uuid4", lambda: next(ids))
    monkeypatch.setattr(
        module,
        "get_request_json",
        lambda: _AwaitableValue({"credentials": json.dumps({"web": {"client_id": "id", "client_secret": "secret"}})}),
    )
    _set_request(module, args={"type": "gmail"})
    gmail_ok = _run(module.start_google_web_oauth())
    assert gmail_ok["code"] == 0
    assert gmail_ok["data"]["flow_id"] == "flow-gmail"
    assert gmail_ok["data"]["authorization_url"].endswith("flow-gmail")
    _set_request(module, args={})
    drive_ok = _run(module.start_google_web_oauth())
    assert drive_ok["code"] == 0
    assert drive_ok["data"]["flow_id"] == "flow-drive"
    assert drive_ok["data"]["authorization_url"].endswith("flow-drive")
    # Each source type must request its own scope set and cache its flow state.
    assert any(call.scopes == module.GOOGLE_SCOPES[module.DocumentSource.GMAIL] for call in flow_calls)
    assert any(call.scopes == module.GOOGLE_SCOPES[module.DocumentSource.GOOGLE_DRIVE] for call in flow_calls)
    assert "gmail_web_flow_state:flow-gmail" in redis.store
    assert "google-drive_web_flow_state:flow-drive" in redis.store
@pytest.mark.p2
def test_google_web_oauth_callbacks_matrix(monkeypatch):
    """Drive both Google OAuth callbacks through their error and success branches."""
    module = _load_connector_app(monkeypatch)
    flow_calls = []

    def _from_client_config(client_config, scopes):
        flow = _FakeFlow(client_config, scopes)
        flow_calls.append(flow)
        return flow

    monkeypatch.setattr(module.Flow, "from_client_config", staticmethod(_from_client_config))
    # Each callback is exercised with its own redirect URI and scope set.
    callback_specs = [
        (
            module.google_gmail_web_oauth_callback,
            "gmail",
            module.GMAIL_WEB_OAUTH_REDIRECT_URI,
            module.GOOGLE_SCOPES[module.DocumentSource.GMAIL],
        ),
        (
            module.google_drive_web_oauth_callback,
            "google-drive",
            module.GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI,
            module.GOOGLE_SCOPES[module.DocumentSource.GOOGLE_DRIVE],
        ),
    ]
    for callback, source, expected_redirect, expected_scopes in callback_specs:
        # Fresh redis per callback so cache-key assertions stay independent.
        redis = _FakeRedis()
        monkeypatch.setattr(module, "REDIS_CONN", redis)
        _set_request(module, args={})
        missing_state = _run(callback())
        assert "Missing OAuth state parameter." in missing_state.body
        _set_request(module, args={"state": "sid"})
        expired_state = _run(callback())
        assert "Authorization session expired" in expired_state.body
        # Cached state without a client_config is invalid and gets purged.
        redis.store[module._web_state_cache_key("sid", source)] = json.dumps({"user_id": "tenant-1"})
        _set_request(module, args={"state": "sid"})
        invalid_state = _run(callback())
        assert "Authorization session was invalid" in invalid_state.body
        assert module._web_state_cache_key("sid", source) in redis.deleted
        # Provider-reported OAuth error is surfaced in the popup body.
        redis.store[module._web_state_cache_key("sid", source)] = json.dumps({
            "user_id": "tenant-1",
            "client_config": {"web": {"client_id": "cid"}},
        })
        _set_request(module, args={"state": "sid", "error": "denied", "error_description": "permission denied"})
        oauth_error = _run(callback())
        assert "permission denied" in oauth_error.body
        redis.store[module._web_state_cache_key("sid", source)] = json.dumps({
            "user_id": "tenant-1",
            "client_config": {"web": {"client_id": "cid"}},
        })
        _set_request(module, args={"state": "sid"})
        missing_code = _run(callback())
        assert "Missing authorization code" in missing_code.body
        # Success: token is fetched, result cached, state entry consumed.
        redis.store[module._web_state_cache_key("sid", source)] = json.dumps({
            "user_id": "tenant-1",
            "client_config": {"web": {"client_id": "cid"}},
        })
        _set_request(module, args={"state": "sid", "code": "code-123"})
        success = _run(callback())
        assert "Authorization completed successfully." in success.body
        result_key = module._web_result_cache_key("sid", source)
        assert result_key in redis.store
        assert module._web_state_cache_key("sid", source) in redis.deleted
        # The flow must be rebuilt with the callback-specific redirect/scopes.
        assert flow_calls[-1].redirect_uri == expected_redirect
        assert flow_calls[-1].scopes == expected_scopes
        assert flow_calls[-1].token_code == "code-123"
@pytest.mark.p2
def test_poll_google_web_result_matrix(monkeypatch):
    """``poll_google_web_result``: bad type, pending, wrong user, then success."""
    module = _load_connector_app(monkeypatch)
    redis = _FakeRedis()
    monkeypatch.setattr(module, "REDIS_CONN", redis)
    _set_request(module, args={"type": "invalid"}, json_body={"flow_id": "flow-1"})
    invalid_type = _run(module.poll_google_web_result())
    assert invalid_type["code"] == module.RetCode.ARGUMENT_ERROR
    # No cached result yet -> still running.
    _set_request(module, args={"type": "gmail"}, json_body={"flow_id": "flow-1"})
    pending = _run(module.poll_google_web_result())
    assert pending["code"] == module.RetCode.RUNNING
    # A result belonging to a different user is rejected.
    redis.store[module._web_result_cache_key("flow-1", "gmail")] = json.dumps(
        {"user_id": "another-user", "credentials": "token-x"}
    )
    _set_request(module, args={"type": "gmail"}, json_body={"flow_id": "flow-1"})
    permission_error = _run(module.poll_google_web_result())
    assert permission_error["code"] == module.RetCode.PERMISSION_ERROR
    # Matching user receives the credentials and the cache entry is consumed.
    redis.store[module._web_result_cache_key("flow-1", "gmail")] = json.dumps(
        {"user_id": "tenant-1", "credentials": "token-ok"}
    )
    _set_request(module, args={"type": "gmail"}, json_body={"flow_id": "flow-1"})
    success = _run(module.poll_google_web_result())
    assert success["code"] == 0
    assert success["data"] == {"credentials": "token-ok"}
    assert module._web_result_cache_key("flow-1", "gmail") in redis.deleted
@pytest.mark.p2
def test_box_oauth_start_callback_and_poll_matrix(monkeypatch):
    """End-to-end matrix for the Box web OAuth start, callback and poll routes."""
    module = _load_connector_app(monkeypatch)
    redis = _FakeRedis()
    monkeypatch.setattr(module, "REDIS_CONN", redis)
    created_auth = []

    class _TrackingBoxOAuth(_FakeBoxOAuth):
        # Track each constructed auth client so the exchanged code can be checked.
        def __init__(self, config):
            super().__init__(config)
            created_auth.append(self)

    monkeypatch.setattr(module, "BoxOAuth", _TrackingBoxOAuth)
    monkeypatch.setattr(module.uuid, "uuid4", lambda: "flow-box")
    monkeypatch.setattr(module.time, "time", lambda: 1800000000)
    # start: missing client_id/client_secret/redirect_uri.
    monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({}))
    missing_params = _run(module.start_box_web_oauth())
    assert missing_params["code"] == module.RetCode.ARGUMENT_ERROR
    # start: success caches the flow state under the generated flow id.
    monkeypatch.setattr(
        module,
        "get_request_json",
        lambda: _AwaitableValue({"client_id": "cid", "client_secret": "sec", "redirect_uri": "https://box.local/callback"}),
    )
    start_ok = _run(module.start_box_web_oauth())
    assert start_ok["code"] == 0
    assert start_ok["data"]["flow_id"] == "flow-box"
    assert "authorization_url" in start_ok["data"]
    assert module._web_state_cache_key("flow-box", "box") in redis.store
    # callback: parameter-validation branches.
    _set_request(module, args={})
    missing_state = _run(module.box_web_oauth_callback())
    assert "Missing OAuth parameters." in missing_state.body
    _set_request(module, args={"state": "flow-box"})
    missing_code = _run(module.box_web_oauth_callback())
    assert "Missing authorization code from Box." in missing_code.body
    # callback: a cached literal "null" state is rejected as invalid.
    redis.store[module._web_state_cache_key("flow-null", "box")] = "null"
    _set_request(module, args={"state": "flow-null", "code": "abc"})
    invalid_session = _run(module.box_web_oauth_callback())
    assert invalid_session["code"] == module.RetCode.ARGUMENT_ERROR
    # callback: provider-reported error is surfaced in the popup body.
    redis.store[module._web_state_cache_key("flow-box", "box")] = json.dumps(
        {"user_id": "tenant-1", "client_id": "cid", "client_secret": "sec"}
    )
    _set_request(module, args={"state": "flow-box", "code": "abc", "error": "access_denied", "error_description": "denied"})
    callback_error = _run(module.box_web_oauth_callback())
    assert "denied" in callback_error.body
    # callback: success exchanges the code and moves state -> result cache.
    redis.store[module._web_state_cache_key("flow-ok", "box")] = json.dumps(
        {"user_id": "tenant-1", "client_id": "cid", "client_secret": "sec"}
    )
    _set_request(module, args={"state": "flow-ok", "code": "code-ok"})
    callback_success = _run(module.box_web_oauth_callback())
    assert "Authorization completed successfully." in callback_success.body
    assert created_auth[-1].exchange_code == "code-ok"
    assert module._web_result_cache_key("flow-ok", "box") in redis.store
    assert module._web_state_cache_key("flow-ok", "box") in redis.deleted
    # poll: pending, wrong user, then success (which consumes the entry).
    monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"flow_id": "flow-ok"}))
    redis.store.pop(module._web_result_cache_key("flow-ok", "box"), None)
    pending = _run(module.poll_box_web_result())
    assert pending["code"] == module.RetCode.RUNNING
    redis.store[module._web_result_cache_key("flow-ok", "box")] = json.dumps({"user_id": "another-user"})
    permission_error = _run(module.poll_box_web_result())
    assert permission_error["code"] == module.RetCode.PERMISSION_ERROR
    redis.store[module._web_result_cache_key("flow-ok", "box")] = json.dumps(
        {"user_id": "tenant-1", "access_token": "at", "refresh_token": "rt"}
    )
    poll_success = _run(module.poll_box_web_result())
    assert poll_success["code"] == 0
    assert poll_success["data"]["credentials"]["access_token"] == "at"
    assert module._web_result_cache_key("flow-ok", "box") in redis.deleted

View File

@ -578,7 +578,189 @@ def test_sequence2txt_validation_and_transcription_paths(monkeypatch):
@pytest.mark.p2
def test_tts_request_parse_entry(monkeypatch):
    """Exercise the tts endpoint: tenant-lookup failures, streamed synthesis, and in-stream errors."""
    module = _load_conversation_module(monkeypatch)
    # Two sentences separated by "。" so the endpoint synthesizes one chunk per sentence.
    # (The original duplicate {"text": "hello"} setup was dead — it was overwritten
    # before any request executed, so it has been removed.)
    _set_request_json(monkeypatch, module, {"text": "A。B"})
    # No tenant record at all -> explicit error message.
    monkeypatch.setattr(module.TenantService, "get_info_by", lambda _uid: [])
    res = _run(module.tts())
    assert res["message"] == "Tenant not found!"
    # Tenant exists but has no default TTS model configured.
    monkeypatch.setattr(module.TenantService, "get_info_by", lambda _uid: [{"tenant_id": "tenant-1", "tts_id": ""}])
    res = _run(module.tts())
    assert res["message"] == "No default TTS model is set"

    class _TTSOk:
        # Minimal synthesizer: yields one encoded chunk per non-empty text fragment.
        def tts(self, txt):
            if not txt:
                return []
            yield f"chunk-{txt}".encode("utf-8")

    monkeypatch.setattr(module.TenantService, "get_info_by", lambda _uid: [{"tenant_id": "tenant-1", "tts_id": "tts-x"}])
    monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: _TTSOk())
    resp = _run(module.tts())
    # Successful synthesis streams audio with no-buffering streaming headers.
    assert resp.mimetype == "audio/mpeg"
    assert resp.headers.get("Cache-Control") == "no-cache"
    assert resp.headers.get("Connection") == "keep-alive"
    assert resp.headers.get("X-Accel-Buffering") == "no"
    stream_text = _run(_read_sse_text(resp))
    assert "chunk-A" in stream_text
    assert "chunk-B" in stream_text

    class _TTSErr:
        # Synthesizer that always fails, to exercise the error frame path.
        def tts(self, _txt):
            raise RuntimeError("tts boom")

    # A failing synthesizer surfaces the error inside the stream, not as an HTTP error.
    monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: _TTSErr())
    resp = _run(module.tts())
    stream_text = _run(_read_sse_text(resp))
    assert '"code": 500' in stream_text
    assert "**ERROR**: tts boom" in stream_text
@pytest.mark.p2
def test_delete_msg_and_thumbup_matrix_unit(monkeypatch):
    """Cover delete_msg and thumbup branches: missing conversation, pair deletion, feedback add/clear."""
    module = _load_conversation_module(monkeypatch)
    updates = []
    # Record every persistence call so we can assert which conversation was written.
    monkeypatch.setattr(module.ConversationService, "update_by_id", lambda conv_id, payload: updates.append((conv_id, payload)) or True)
    # delete_msg: conversation lookup fails.
    _set_request_json(monkeypatch, module, {"conversation_id": "missing", "message_id": "pair-1"})
    monkeypatch.setattr(module.ConversationService, "get_by_id", lambda _id: (False, None))
    res = _run(module.delete_msg.__wrapped__())
    assert res["message"] == "Conversation not found!"
    # delete_msg: removing a user/assistant pair also drops the paired reference entry.
    conv = _DummyConversation(
        conv_id="conv-del",
        message=[
            {"id": "other", "role": "user"},
            {"id": "pair-1", "role": "user"},
            {"id": "pair-1", "role": "assistant"},
        ],
        reference=[{"chunks": [{"id": "c1"}]}],
    )
    _set_request_json(monkeypatch, module, {"conversation_id": "conv-del", "message_id": "pair-1"})
    monkeypatch.setattr(module.ConversationService, "get_by_id", lambda _id: (True, conv))
    res = _run(module.delete_msg.__wrapped__())
    assert res["code"] == 0
    assert [m["id"] for m in res["data"]["message"]] == ["other"]
    assert res["data"]["reference"] == []
    assert updates[-1][0] == "conv-del"
    # thumbup: conversation lookup fails.
    _set_request_json(monkeypatch, module, {"conversation_id": "missing", "message_id": "assistant-1", "thumbup": True})
    monkeypatch.setattr(module.ConversationService, "get_by_id", lambda _id: (False, None))
    res = _run(module.thumbup.__wrapped__())
    assert res["message"] == "Conversation not found!"
    # thumbup=True clears any previously stored feedback on the assistant message.
    conv_up = _DummyConversation(
        conv_id="conv-up",
        message=[{"id": "assistant-1", "role": "assistant", "feedback": "old"}],
    )
    _set_request_json(monkeypatch, module, {"conversation_id": "conv-up", "message_id": "assistant-1", "thumbup": True})
    monkeypatch.setattr(module.ConversationService, "get_by_id", lambda _id: (True, conv_up))
    res = _run(module.thumbup.__wrapped__())
    assert res["code"] == 0
    assert res["data"]["message"][0]["thumbup"] is True
    assert "feedback" not in res["data"]["message"][0]
    # thumbup=False stores the supplied feedback text alongside the flag.
    conv_down = _DummyConversation(conv_id="conv-down", message=[{"id": "assistant-2", "role": "assistant"}])
    _set_request_json(
        monkeypatch,
        module,
        {"conversation_id": "conv-down", "message_id": "assistant-2", "thumbup": False, "feedback": "needs sources"},
    )
    monkeypatch.setattr(module.ConversationService, "get_by_id", lambda _id: (True, conv_down))
    res = _run(module.thumbup.__wrapped__())
    assert res["code"] == 0
    assert res["data"]["message"][0]["thumbup"] is False
    assert res["data"]["message"][0]["feedback"] == "needs sources"
@pytest.mark.p2
def test_ask_about_stream_search_config_matrix_unit(monkeypatch):
    """ask_about streams SSE frames and forwards the search_id's search_config to async_ask."""
    module = _load_conversation_module(monkeypatch)
    _set_request_json(monkeypatch, module, {"question": "q", "kb_ids": ["kb-1"], "search_id": "search-1"})
    monkeypatch.setattr(module.SearchService, "get_detail", lambda _sid: {"search_config": {"mode": "test"}})
    captured = {}

    async def _fake_async_ask(question, kb_ids, uid, search_config=None):
        # Record the forwarded arguments, emit one answer, then fail mid-stream.
        captured["question"] = question
        captured["kb_ids"] = kb_ids
        captured["uid"] = uid
        captured["search_config"] = search_config
        yield {"answer": "first"}
        raise RuntimeError("ask boom")

    monkeypatch.setattr(module, "async_ask", _fake_async_ask)
    resp = _run(module.ask_about.__wrapped__())
    assert resp.headers["Content-Type"] == "text/event-stream; charset=utf-8"
    sse_text = _run(_read_sse_text(resp))
    # The first answer, the in-stream error frame, and the terminating "data: true" frame all appear.
    assert '"answer": "first"' in sse_text
    assert "**ERROR**: ask boom" in sse_text
    assert '"data": true' in sse_text.lower()
    assert captured == {"question": "q", "kb_ids": ["kb-1"], "uid": "user-1", "search_config": {"mode": "test"}}
@pytest.mark.p2
def test_mindmap_and_related_questions_matrix_unit(monkeypatch):
    """mindmap merges request and search-config kb_ids; related_questions parses numbered LLM output."""
    module = _load_conversation_module(monkeypatch)

    def _search_detail(_sid):
        # Search config contributing extra kb_ids, a chat_id, and an llm_setting
        # with one scalar and one nested entry.
        return {
            "tenant_id": "tenant-x",
            "search_config": {
                "kb_ids": ["kb-2", "kb-3"],
                "chat_id": "chat-x",
                "llm_setting": {"temperature": 0.2, "parameter": {"k": "v"}},
            },
        }

    monkeypatch.setattr(module.SearchService, "get_detail", _search_detail)
    _set_request_json(monkeypatch, module, {"question": "mindmap-q", "kb_ids": ["kb-1", "kb-2"], "search_id": "search-1"})
    mindmap_calls = {}

    async def _gen_ok(question, kb_ids, tenant_id, search_config):
        # Capture what mindmap forwards, then return a minimal graph.
        mindmap_calls["question"] = question
        mindmap_calls["kb_ids"] = set(kb_ids)
        mindmap_calls["tenant_id"] = tenant_id
        mindmap_calls["search_config"] = search_config
        return {"nodes": [question]}

    monkeypatch.setattr(module, "gen_mindmap", _gen_ok)
    res = _run(module.mindmap.__wrapped__())
    assert res["code"] == 0
    assert res["data"] == {"nodes": ["mindmap-q"]}
    # kb_ids from the request are unioned with those from the search config.
    assert mindmap_calls["kb_ids"] == {"kb-1", "kb-2", "kb-3"}
    assert mindmap_calls["tenant_id"] == "tenant-x"
    assert set(mindmap_calls["search_config"]["kb_ids"]) == {"kb-1", "kb-2", "kb-3"}

    async def _gen_error(*_args, **_kwargs):
        return {"error": "mindmap boom"}

    # gen_mindmap returning an "error" key is surfaced in the response message.
    monkeypatch.setattr(module, "gen_mindmap", _gen_error)
    res = _run(module.mindmap.__wrapped__())
    assert "mindmap boom" in res["message"]
    llm_calls = {}

    class _FakeChat:
        async def async_chat(self, prompt, messages, options):
            llm_calls["prompt"] = prompt
            llm_calls["messages"] = messages
            llm_calls["options"] = options
            # Only the numbered lines should survive parsing; "ignored" should not.
            return "1. Alpha\n2. Beta\nignored"

    def _fake_bundle(tenant_id, llm_type, chat_id):
        llm_calls["bundle"] = (tenant_id, llm_type, chat_id)
        return _FakeChat()

    monkeypatch.setattr(module, "LLMBundle", _fake_bundle)
    monkeypatch.setattr(module, "load_prompt", lambda name: f"prompt-{name}")
    _set_request_json(monkeypatch, module, {"question": "solar", "search_id": "search-1"})
    res = _run(module.related_questions.__wrapped__())
    assert res["code"] == 0
    assert res["data"] == ["Alpha", "Beta"]
    # The LLM bundle is built for the current user with the search config's chat_id,
    # and only the scalar "temperature" survives from llm_setting (nested "parameter" dropped).
    assert llm_calls["bundle"][0] == "user-1"
    assert llm_calls["bundle"][2] == "chat-x"
    assert llm_calls["options"] == {"temperature": 0.2}
    assert llm_calls["prompt"] == "prompt-related_question"
    assert "Keywords: solar" in llm_calls["messages"][0]["content"]

View File

@ -0,0 +1,778 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import functools
import importlib.util
import inspect
import json
import os
import sys
from copy import deepcopy
from enum import Enum
from pathlib import Path
from types import ModuleType, SimpleNamespace
import pytest
class _DummyManager:
def route(self, *_args, **_kwargs):
def decorator(func):
return func
return decorator
class _AwaitableValue:
def __init__(self, value):
self._value = value
def __await__(self):
async def _co():
return self._value
return _co().__await__()
class _DummyArgs(dict):
def get(self, key, default=None, type=None):
value = super().get(key, default)
if value is None or type is None:
return value
try:
return type(value)
except (TypeError, ValueError):
return default
class _Field:
def __init__(self, name):
self.name = name
def __eq__(self, other):
return (self.name, "==", other)
class _KB:
def __init__(
self,
*,
kb_id="kb-1",
name="old",
tenant_id="tenant-1",
parser_id="naive",
parser_config=None,
embd_id="embd-1",
chunk_num=0,
pagerank=0,
graphrag_task_id="",
raptor_task_id="",
):
self.id = kb_id
self.name = name
self.tenant_id = tenant_id
self.parser_id = parser_id
self.parser_config = parser_config or {}
self.embd_id = embd_id
self.chunk_num = chunk_num
self.pagerank = pagerank
self.graphrag_task_id = graphrag_task_id
self.raptor_task_id = raptor_task_id
def to_dict(self):
return {
"id": self.id,
"name": self.name,
"tenant_id": self.tenant_id,
"parser_id": self.parser_id,
"parser_config": deepcopy(self.parser_config),
"embd_id": self.embd_id,
"pagerank": self.pagerank,
}
def _run(coro):
return asyncio.run(coro)
@pytest.fixture(scope="session")
def auth():
    """Placeholder auth token; these unit tests never exercise real authentication."""
    return "unit-auth"
@pytest.fixture(scope="session", autouse=True)
def set_tenant_info():
    """Autouse no-op standing in for the suite-level tenant bootstrap fixture."""
    return None
def _set_request_args(monkeypatch, module, args):
    """Swap *module*'s quart request for a stub exposing only query-string args."""
    fake_request = SimpleNamespace(args=_DummyArgs(args))
    monkeypatch.setattr(module, "request", fake_request)
def _patch_json_parser(monkeypatch, module, payload_state, err_state=None):
async def _parse_json(*_args, **_kwargs):
return deepcopy(payload_state), err_state
monkeypatch.setattr(module, "validate_and_parse_json_request", _parse_json)
def _load_dataset_module(monkeypatch):
    """Import api/apps/sdk/dataset.py in isolation, stubbing every external dependency.

    Installs fake quart / api / common / rag module trees into sys.modules via
    monkeypatch so the real dataset routes can be executed without a running app,
    database, or document store. Returns the freshly executed module object.
    """
    repo_root = Path(__file__).resolve().parents[4]
    # Minimal quart stand-in: the routes only ever touch `request.args`.
    quart_mod = ModuleType("quart")
    quart_mod.Request = type("Request", (), {})
    quart_mod.request = SimpleNamespace(args=_DummyArgs())
    monkeypatch.setitem(sys.modules, "quart", quart_mod)
    # Skeleton "api" package hierarchy pointing at the real source tree so
    # spec_from_file_location below can resolve relative files.
    api_pkg = ModuleType("api")
    api_pkg.__path__ = [str(repo_root / "api")]
    monkeypatch.setitem(sys.modules, "api", api_pkg)
    utils_pkg = ModuleType("api.utils")
    utils_pkg.__path__ = [str(repo_root / "api" / "utils")]
    monkeypatch.setitem(sys.modules, "api.utils", utils_pkg)
    api_pkg.utils = utils_pkg
    apps_pkg = ModuleType("api.apps")
    apps_pkg.__path__ = [str(repo_root / "api" / "apps")]
    monkeypatch.setitem(sys.modules, "api.apps", apps_pkg)
    api_pkg.apps = apps_pkg
    sdk_pkg = ModuleType("api.apps.sdk")
    sdk_pkg.__path__ = [str(repo_root / "api" / "apps" / "sdk")]
    monkeypatch.setitem(sys.modules, "api.apps.sdk", sdk_pkg)
    apps_pkg.sdk = sdk_pkg
    db_pkg = ModuleType("api.db")
    db_pkg.__path__ = []
    monkeypatch.setitem(sys.modules, "api.db", db_pkg)
    api_pkg.db = db_pkg
    # File model stub: only the column descriptors used in delete filters.
    db_models_mod = ModuleType("api.db.db_models")
    db_models_mod.File = SimpleNamespace(
        source_type=_Field("source_type"),
        id=_Field("id"),
        type=_Field("type"),
        name=_Field("name"),
    )
    monkeypatch.setitem(sys.modules, "api.db.db_models", db_models_mod)
    services_pkg = ModuleType("api.db.services")
    services_pkg.__path__ = []
    monkeypatch.setitem(sys.modules, "api.db.services", services_pkg)
    document_service_mod = ModuleType("api.db.services.document_service")

    class _StubDocumentService:
        # Benign defaults; individual tests monkeypatch the interesting methods.
        @staticmethod
        def query(**_kwargs):
            return []

        @staticmethod
        def remove_document(*_args, **_kwargs):
            return True

        @staticmethod
        def get_by_kb_id(**_kwargs):
            return [], 0

    document_service_mod.DocumentService = _StubDocumentService
    document_service_mod.queue_raptor_o_graphrag_tasks = lambda **_kwargs: "task-queued"
    monkeypatch.setitem(sys.modules, "api.db.services.document_service", document_service_mod)
    services_pkg.document_service = document_service_mod
    file2document_service_mod = ModuleType("api.db.services.file2document_service")

    class _StubFile2DocumentService:
        @staticmethod
        def get_by_document_id(_doc_id):
            return [SimpleNamespace(file_id="file-1")]

        @staticmethod
        def delete_by_document_id(_doc_id):
            return None

    file2document_service_mod.File2DocumentService = _StubFile2DocumentService
    monkeypatch.setitem(sys.modules, "api.db.services.file2document_service", file2document_service_mod)
    services_pkg.file2document_service = file2document_service_mod
    file_service_mod = ModuleType("api.db.services.file_service")

    class _StubFileService:
        @staticmethod
        def filter_delete(_filters):
            return None

    file_service_mod.FileService = _StubFileService
    monkeypatch.setitem(sys.modules, "api.db.services.file_service", file_service_mod)
    services_pkg.file_service = file_service_mod
    knowledgebase_service_mod = ModuleType("api.db.services.knowledgebase_service")

    class _StubKnowledgebaseService:
        # Happy-path defaults for every method the dataset routes call; tests
        # override specific ones to drive error branches.
        @staticmethod
        def create_with_name(**_kwargs):
            return True, {"id": "kb-1"}

        @staticmethod
        def save(**_kwargs):
            return True

        @staticmethod
        def get_by_id(_kb_id):
            return True, _KB()

        @staticmethod
        def query(**_kwargs):
            return []

        @staticmethod
        def get_or_none(**_kwargs):
            return _KB()

        @staticmethod
        def delete_by_id(_kb_id):
            return True

        @staticmethod
        def update_by_id(_kb_id, _payload):
            return True

        @staticmethod
        def get_kb_by_id(_kb_id, _tenant_id):
            return [SimpleNamespace(id=_kb_id)]

        @staticmethod
        def get_kb_by_name(_name, _tenant_id):
            return [SimpleNamespace(name=_name)]

        @staticmethod
        def get_list(*_args, **_kwargs):
            return [], 0

        @staticmethod
        def accessible(_dataset_id, _tenant_id):
            return True

    knowledgebase_service_mod.KnowledgebaseService = _StubKnowledgebaseService
    monkeypatch.setitem(sys.modules, "api.db.services.knowledgebase_service", knowledgebase_service_mod)
    services_pkg.knowledgebase_service = knowledgebase_service_mod
    task_service_mod = ModuleType("api.db.services.task_service")

    class _StubTaskService:
        @staticmethod
        def get_by_id(_task_id):
            return False, None

    task_service_mod.GRAPH_RAPTOR_FAKE_DOC_ID = "fake-doc"
    task_service_mod.TaskService = _StubTaskService
    monkeypatch.setitem(sys.modules, "api.db.services.task_service", task_service_mod)
    services_pkg.task_service = task_service_mod
    user_service_mod = ModuleType("api.db.services.user_service")

    class _StubTenantService:
        @staticmethod
        def get_by_id(_tenant_id):
            return True, SimpleNamespace(embd_id="embd-default")

        @staticmethod
        def get_joined_tenants_by_user_id(_tenant_id):
            return [{"tenant_id": "tenant-1"}]

    user_service_mod.TenantService = _StubTenantService
    monkeypatch.setitem(sys.modules, "api.db.services.user_service", user_service_mod)
    services_pkg.user_service = user_service_mod
    # common.constants stub with just the codes/enums the routes reference.
    constants_mod = ModuleType("common.constants")

    class _RetCode:
        SUCCESS = 0
        ARGUMENT_ERROR = 101
        DATA_ERROR = 102
        AUTHENTICATION_ERROR = 108

    class _FileSource:
        KNOWLEDGEBASE = "knowledgebase"

    class _StatusEnum(Enum):
        VALID = "valid"

    constants_mod.RetCode = _RetCode
    constants_mod.FileSource = _FileSource
    constants_mod.StatusEnum = _StatusEnum
    constants_mod.PAGERANK_FLD = "pagerank"
    monkeypatch.setitem(sys.modules, "common.constants", constants_mod)
    common_pkg = ModuleType("common")
    common_pkg.__path__ = [str(repo_root / "common")]
    # settings stub: an inert doc store plus a retriever that resolves to an
    # empty result set (via _AwaitableValue, since routes await the search).
    common_pkg.settings = SimpleNamespace(
        docStoreConn=SimpleNamespace(
            delete_idx=lambda *_args, **_kwargs: None,
            delete=lambda *_args, **_kwargs: None,
            update=lambda *_args, **_kwargs: None,
            index_exist=lambda *_args, **_kwargs: False,
        ),
        retriever=SimpleNamespace(search=lambda *_args, **_kwargs: _AwaitableValue(SimpleNamespace(ids=[], field={}))),
    )
    monkeypatch.setitem(sys.modules, "common", common_pkg)
    # Re-implement the small api_utils helpers instead of importing the real module.
    api_utils_mod = ModuleType("api.utils.api_utils")

    def _deep_merge(base, updates):
        merged = deepcopy(base)
        for key, value in updates.items():
            if isinstance(value, dict) and isinstance(merged.get(key), dict):
                merged[key] = _deep_merge(merged[key], value)
            else:
                merged[key] = value
        return merged

    def _get_result(*, data=None, message="", code=_RetCode.SUCCESS, total=None):
        payload = {"code": code, "data": data, "message": message}
        if total is not None:
            payload["total"] = total
        return payload

    def _get_error_argument_result(message=""):
        return _get_result(code=_RetCode.ARGUMENT_ERROR, message=message)

    def _get_error_data_result(message=""):
        return _get_result(code=_RetCode.DATA_ERROR, message=message)

    def _get_error_permission_result(message=""):
        return _get_result(code=_RetCode.AUTHENTICATION_ERROR, message=message)

    def _token_required(func):
        # Pass-through auth decorator that preserves the route's sync/async-ness.
        @functools.wraps(func)
        async def _async_wrapper(*args, **kwargs):
            return await func(*args, **kwargs)

        @functools.wraps(func)
        def _sync_wrapper(*args, **kwargs):
            return func(*args, **kwargs)

        return _async_wrapper if asyncio.iscoroutinefunction(func) else _sync_wrapper

    api_utils_mod.deep_merge = _deep_merge
    api_utils_mod.get_error_argument_result = _get_error_argument_result
    api_utils_mod.get_error_data_result = _get_error_data_result
    api_utils_mod.get_error_permission_result = _get_error_permission_result
    api_utils_mod.get_parser_config = lambda _chunk_method, _unused: {"auto": True}
    api_utils_mod.get_result = _get_result
    api_utils_mod.remap_dictionary_keys = lambda data: data
    api_utils_mod.token_required = _token_required
    api_utils_mod.verify_embedding_availability = lambda _embd_id, _tenant_id: (True, None)
    monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod)

    async def _parse_json(*_args, **_kwargs):
        return {}, None

    def _parse_args(*_args, **_kwargs):
        return {"name": "", "page": 1, "page_size": 30, "orderby": "create_time", "desc": True}, None

    # Load the real validation_utils, then override its two parse entry points
    # with deterministic stubs.
    validation_spec = importlib.util.spec_from_file_location(
        "api.utils.validation_utils", repo_root / "api" / "utils" / "validation_utils.py"
    )
    validation_mod = importlib.util.module_from_spec(validation_spec)
    monkeypatch.setitem(sys.modules, "api.utils.validation_utils", validation_mod)
    validation_spec.loader.exec_module(validation_mod)
    validation_mod.validate_and_parse_json_request = _parse_json
    validation_mod.validate_and_parse_request_args = _parse_args
    rag_pkg = ModuleType("rag")
    rag_pkg.__path__ = []
    monkeypatch.setitem(sys.modules, "rag", rag_pkg)
    rag_nlp_pkg = ModuleType("rag.nlp")
    rag_nlp_pkg.__path__ = []
    monkeypatch.setitem(sys.modules, "rag.nlp", rag_nlp_pkg)
    search_mod = ModuleType("rag.nlp.search")
    search_mod.index_name = lambda _tenant_id: "idx"
    monkeypatch.setitem(sys.modules, "rag.nlp.search", search_mod)
    rag_nlp_pkg.search = search_mod
    # Finally execute the real dataset.py under a private module name, with the
    # route manager replaced before module code runs.
    module_name = "test_dataset_sdk_routes_unit_module"
    module_path = repo_root / "api" / "apps" / "sdk" / "dataset.py"
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    module = importlib.util.module_from_spec(spec)
    module.manager = _DummyManager()
    monkeypatch.setitem(sys.modules, module_name, module)
    spec.loader.exec_module(module)
    return module
@pytest.mark.p2
def test_create_route_error_matrix_unit(monkeypatch):
    """Walk the create route's failure branches: service error, missing tenant, save failures."""
    module = _load_dataset_module(monkeypatch)
    req_state = {"name": "kb"}
    _patch_json_parser(monkeypatch, module, req_state)
    # create_with_name failing returns its error payload verbatim.
    monkeypatch.setattr(module.KnowledgebaseService, "create_with_name", lambda **_kwargs: (False, {"code": 777, "message": "early"}))
    res = _run(inspect.unwrap(module.create)("tenant-1"))
    assert res["code"] == 777, res
    # Dataset created but the owning tenant cannot be resolved.
    monkeypatch.setattr(module.KnowledgebaseService, "create_with_name", lambda **_kwargs: (True, {"id": "kb-1"}))
    monkeypatch.setattr(module.TenantService, "get_by_id", lambda _tenant_id: (False, None))
    res = _run(inspect.unwrap(module.create)("tenant-1"))
    assert res["message"] == "Tenant not found", res
    # Persisting the new dataset fails.
    monkeypatch.setattr(module.TenantService, "get_by_id", lambda _tenant_id: (True, SimpleNamespace(embd_id="embd-1")))
    monkeypatch.setattr(module.KnowledgebaseService, "save", lambda **_kwargs: False)
    res = _run(inspect.unwrap(module.create)("tenant-1"))
    assert res["code"] == module.RetCode.DATA_ERROR, res
    # Saved, but re-reading the freshly created dataset fails.
    monkeypatch.setattr(module.KnowledgebaseService, "save", lambda **_kwargs: True)
    monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (False, None))
    res = _run(inspect.unwrap(module.create)("tenant-1"))
    assert "Dataset created failed" in res["message"], res
    # Unexpected exceptions are reported as a generic database failure.
    monkeypatch.setattr(module.KnowledgebaseService, "save", lambda **_kwargs: (_ for _ in ()).throw(RuntimeError("save boom")))
    res = _run(inspect.unwrap(module.create)("tenant-1"))
    assert res["message"] == "Database operation failed", res
@pytest.mark.p2
def test_delete_route_error_summary_matrix_unit(monkeypatch):
    """Cover the delete route's partial-failure summary and OperationalError handling."""
    module = _load_dataset_module(monkeypatch)
    req_state = {"ids": ["kb-1"]}
    _patch_json_parser(monkeypatch, module, req_state)
    kb = _KB(kb_id="kb-1", name="kb-1", tenant_id="tenant-1")
    monkeypatch.setattr(module.KnowledgebaseService, "get_or_none", lambda **_kwargs: kb)
    monkeypatch.setattr(module.DocumentService, "query", lambda **_kwargs: [SimpleNamespace(id="doc-1")])
    # Every sub-step fails -> zero successes reported as a DATA_ERROR summary.
    monkeypatch.setattr(module.DocumentService, "remove_document", lambda *_args, **_kwargs: False)
    monkeypatch.setattr(module.settings, "docStoreConn", SimpleNamespace(delete_idx=lambda *_args, **_kwargs: (_ for _ in ()).throw(RuntimeError("drop failed"))))
    monkeypatch.setattr(module.KnowledgebaseService, "delete_by_id", lambda _kb_id: False)
    res = _run(inspect.unwrap(module.delete)("tenant-1"))
    assert res["code"] == module.RetCode.DATA_ERROR, res
    assert "Successfully deleted 0 datasets" in res["message"], res
    # KB row deletion now succeeds; document-level errors are still collected.
    monkeypatch.setattr(module.settings, "docStoreConn", SimpleNamespace(delete_idx=lambda *_args, **_kwargs: None))
    monkeypatch.setattr(module.KnowledgebaseService, "delete_by_id", lambda _kb_id: True)
    res = _run(inspect.unwrap(module.delete)("tenant-1"))
    assert res["code"] == module.RetCode.SUCCESS, res
    assert res["data"]["success_count"] == 1, res
    assert res["data"]["errors"], res
    # ids=None takes the delete-all path, where a DB outage is reported generically.
    req_state["ids"] = None
    monkeypatch.setattr(
        module.KnowledgebaseService,
        "query",
        lambda **_kwargs: (_ for _ in ()).throw(module.OperationalError("db down")),
    )
    res = _run(inspect.unwrap(module.delete)("tenant-1"))
    assert res["code"] == module.RetCode.DATA_ERROR, res
    assert res["message"] == "Database operation failed", res
@pytest.mark.p2
def test_update_route_branch_matrix_unit(monkeypatch):
    """Walk the update route's branches: auth, duplicate name, embedding lock, pagerank, failures."""
    module = _load_dataset_module(monkeypatch)
    req_state = {"name": "new"}
    _patch_json_parser(monkeypatch, module, req_state)
    # Unknown dataset id -> permission error.
    monkeypatch.setattr(module.KnowledgebaseService, "get_or_none", lambda **_kwargs: None)
    res = _run(inspect.unwrap(module.update)("tenant-1", "kb-1"))
    assert res["code"] == module.RetCode.AUTHENTICATION_ERROR, res
    # Renaming to a name owned by another dataset is rejected.
    kb = _KB(kb_id="kb-1", name="old", chunk_num=0)

    def _get_or_none_duplicate(**kwargs):
        if kwargs.get("id"):
            return kb
        if kwargs.get("name"):
            return SimpleNamespace(id="dup")
        return None

    monkeypatch.setattr(module.KnowledgebaseService, "get_or_none", _get_or_none_duplicate)
    req_state.clear()
    req_state.update({"name": "new"})
    res = _run(inspect.unwrap(module.update)("tenant-1", "kb-1"))
    assert "already exists" in res["message"], res
    # The embedding model cannot change once chunks exist.
    kb_chunked = _KB(kb_id="kb-1", name="old", chunk_num=2, embd_id="embd-1")
    monkeypatch.setattr(module.KnowledgebaseService, "get_or_none", lambda **kwargs: kb_chunked if kwargs.get("id") else None)
    req_state.clear()
    req_state.update({"embd_id": "embd-2"})
    res = _run(inspect.unwrap(module.update)("tenant-1", "kb-1"))
    assert "chunk_num" in res["message"], res
    # pagerank is rejected on the infinity doc engine. Use monkeypatch.setenv /
    # delenv (instead of mutating os.environ directly) so the variable is
    # restored automatically even if an assertion fails before the cleanup line.
    kb_rank = _KB(kb_id="kb-1", name="old", pagerank=0)
    monkeypatch.setattr(module.KnowledgebaseService, "get_or_none", lambda **kwargs: kb_rank if kwargs.get("id") else None)
    req_state.clear()
    req_state.update({"pagerank": 3})
    monkeypatch.setenv("DOC_ENGINE", "infinity")
    res = _run(inspect.unwrap(module.update)("tenant-1", "kb-1"))
    assert "doc_engine" in res["message"], res
    monkeypatch.delenv("DOC_ENGINE", raising=False)
    # Setting a positive pagerank pushes the value into the doc store.
    update_calls = []
    monkeypatch.setattr(module.settings, "docStoreConn", SimpleNamespace(update=lambda *args, **_kwargs: update_calls.append(args)))
    monkeypatch.setattr(module.KnowledgebaseService, "update_by_id", lambda *_args, **_kwargs: True)
    monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, _KB(kb_id="kb-1", pagerank=3)))
    req_state.clear()
    req_state.update({"pagerank": 3})
    res = _run(inspect.unwrap(module.update)("tenant-1", "kb-1"))
    assert res["code"] == module.RetCode.SUCCESS, res
    assert update_calls and update_calls[-1][0] == {"kb_id": "kb-1"}, update_calls
    # Resetting pagerank to 0 removes the field from the doc store instead.
    update_calls.clear()
    monkeypatch.setattr(module.KnowledgebaseService, "get_or_none", lambda **kwargs: _KB(kb_id="kb-1", pagerank=3) if kwargs.get("id") else None)
    req_state.clear()
    req_state.update({"pagerank": 0})
    res = _run(inspect.unwrap(module.update)("tenant-1", "kb-1"))
    assert res["code"] == module.RetCode.SUCCESS, res
    assert update_calls and update_calls[-1][0] == {"exists": module.PAGERANK_FLD}, update_calls
    # update_by_id failing, then a failing re-read, then an OperationalError.
    monkeypatch.setattr(module.KnowledgebaseService, "update_by_id", lambda *_args, **_kwargs: False)
    req_state.clear()
    req_state.update({"description": "changed"})
    res = _run(inspect.unwrap(module.update)("tenant-1", "kb-1"))
    assert "Update dataset error" in res["message"], res
    monkeypatch.setattr(module.KnowledgebaseService, "update_by_id", lambda *_args, **_kwargs: True)
    monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (False, None))
    res = _run(inspect.unwrap(module.update)("tenant-1", "kb-1"))
    assert "Dataset created failed" in res["message"], res
    monkeypatch.setattr(
        module.KnowledgebaseService,
        "get_or_none",
        lambda **_kwargs: (_ for _ in ()).throw(module.OperationalError("update down")),
    )
    res = _run(inspect.unwrap(module.update)("tenant-1", "kb-1"))
    assert res["message"] == "Database operation failed", res
@pytest.mark.p2
def test_list_knowledge_graph_delete_kg_matrix_unit(monkeypatch):
    """Cover list_datasets DB failure plus knowledge_graph auth/empty/bad-JSON/edge-filter branches."""
    module = _load_dataset_module(monkeypatch)
    _set_request_args(monkeypatch, module, {"id": "", "name": "", "page": 1, "page_size": 30, "orderby": "create_time", "desc": True})
    monkeypatch.setattr(
        module,
        "validate_and_parse_request_args",
        lambda *_args, **_kwargs: ({"name": "", "page": 1, "page_size": 30, "orderby": "create_time", "desc": True}, None),
    )
    # list_datasets surfaces OperationalError as a generic DB failure.
    monkeypatch.setattr(
        module.KnowledgebaseService,
        "get_list",
        lambda *_args, **_kwargs: (_ for _ in ()).throw(module.OperationalError("list down")),
    )
    res = module.list_datasets("tenant-1")
    assert res["code"] == module.RetCode.DATA_ERROR, res
    assert res["message"] == "Database operation failed", res
    # knowledge_graph: inaccessible dataset -> permission error.
    monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda *_args, **_kwargs: False)
    res = _run(inspect.unwrap(module.knowledge_graph)("tenant-1", "kb-1"))
    assert res["code"] == module.RetCode.AUTHENTICATION_ERROR, res
    # Missing index -> empty graph/mind_map payload.
    monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda *_args, **_kwargs: True)
    monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, _KB(tenant_id="tenant-1")))
    monkeypatch.setattr(module.search, "index_name", lambda _tenant_id: "idx")
    monkeypatch.setattr(module.settings, "docStoreConn", SimpleNamespace(index_exist=lambda *_args, **_kwargs: False))
    res = _run(inspect.unwrap(module.knowledge_graph)("tenant-1", "kb-1"))
    assert res["data"] == {"graph": {}, "mind_map": {}}, res
    monkeypatch.setattr(module.settings, "docStoreConn", SimpleNamespace(index_exist=lambda *_args, **_kwargs: True))

    class _EmptyRetriever:
        async def search(self, *_args, **_kwargs):
            return SimpleNamespace(ids=[], field={})

    # Index exists but retrieval returns no hits.
    monkeypatch.setattr(module.settings, "retriever", _EmptyRetriever())
    res = _run(inspect.unwrap(module.knowledge_graph)("tenant-1", "kb-1"))
    assert res["data"] == {"graph": {}, "mind_map": {}}, res

    class _BadRetriever:
        async def search(self, *_args, **_kwargs):
            return SimpleNamespace(ids=["bad"], field={"bad": {"knowledge_graph_kwd": "graph", "content_with_weight": "{bad"}})

    # Malformed graph JSON is swallowed and yields an empty graph.
    monkeypatch.setattr(module.settings, "retriever", _BadRetriever())
    res = _run(inspect.unwrap(module.knowledge_graph)("tenant-1", "kb-1"))
    assert res["code"] == module.RetCode.SUCCESS, res
    assert res["data"]["graph"] == {}, res
    # Valid graph payload: the self-loop edge and the edge to a node that is not
    # in "nodes" are filtered out, leaving both nodes and one edge.
    payload = {
        "nodes": [{"id": "n2", "pagerank": 2}, {"id": "n1", "pagerank": 5}],
        "edges": [
            {"source": "n1", "target": "n2", "weight": 2},
            {"source": "n1", "target": "n1", "weight": 10},
            {"source": "n1", "target": "n3", "weight": 9},
        ],
    }

    class _GoodRetriever:
        async def search(self, *_args, **_kwargs):
            return SimpleNamespace(ids=["good"], field={"good": {"knowledge_graph_kwd": "graph", "content_with_weight": json.dumps(payload)}})

    monkeypatch.setattr(module.settings, "retriever", _GoodRetriever())
    res = _run(inspect.unwrap(module.knowledge_graph)("tenant-1", "kb-1"))
    assert res["code"] == module.RetCode.SUCCESS, res
    assert len(res["data"]["graph"]["nodes"]) == 2, res
    assert len(res["data"]["graph"]["edges"]) == 1, res
    # delete_knowledge_graph honors the accessibility check too.
    monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda *_args, **_kwargs: False)
    res = inspect.unwrap(module.delete_knowledge_graph)("tenant-1", "kb-1")
    assert res["code"] == module.RetCode.AUTHENTICATION_ERROR, res
@pytest.mark.p2
def test_run_trace_graphrag_matrix_unit(monkeypatch):
    """Cover run_graphrag validation, stale-task requeue, duplicate-run guard, and trace_graphrag."""
    module = _load_dataset_module(monkeypatch)
    warnings = []
    # Capture warnings so the "stale task" and "save failed" paths can be asserted.
    monkeypatch.setattr(module.logging, "warning", lambda msg, *_args, **_kwargs: warnings.append(msg))
    # Validation ladder: empty id, inaccessible dataset, unknown dataset.
    res = inspect.unwrap(module.run_graphrag)("tenant-1", "")
    assert 'Dataset ID' in res["message"], res
    monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda *_args, **_kwargs: False)
    res = inspect.unwrap(module.run_graphrag)("tenant-1", "kb-1")
    assert res["code"] == module.RetCode.AUTHENTICATION_ERROR, res
    monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda *_args, **_kwargs: True)
    monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (False, None))
    res = inspect.unwrap(module.run_graphrag)("tenant-1", "kb-1")
    assert "Invalid Dataset ID" in res["message"], res
    # A stale graphrag_task_id (task no longer exists) logs a warning and requeues.
    stale_kb = _KB(kb_id="kb-1", graphrag_task_id="task-old")
    monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, stale_kb))
    monkeypatch.setattr(module.TaskService, "get_by_id", lambda _task_id: (False, None))
    monkeypatch.setattr(module.DocumentService, "get_by_kb_id", lambda **_kwargs: ([{"id": "doc-1"}], 1))
    monkeypatch.setattr(module, "queue_raptor_o_graphrag_tasks", lambda **_kwargs: "task-new")
    monkeypatch.setattr(module.KnowledgebaseService, "update_by_id", lambda *_args, **_kwargs: True)
    res = inspect.unwrap(module.run_graphrag)("tenant-1", "kb-1")
    assert res["code"] == module.RetCode.SUCCESS, res
    assert any("GraphRAG" in msg for msg in warnings), warnings
    # An in-progress task blocks a second run.
    monkeypatch.setattr(module.TaskService, "get_by_id", lambda _task_id: (True, SimpleNamespace(progress=0)))
    res = inspect.unwrap(module.run_graphrag)("tenant-1", "kb-1")
    assert "already running" in res["message"], res
    warnings.clear()
    # Fresh queue: all doc ids are forwarded; a failed task-id save only warns.
    queue_calls = {}
    no_task_kb = _KB(kb_id="kb-1", graphrag_task_id="")
    monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, no_task_kb))
    monkeypatch.setattr(module.TaskService, "get_by_id", lambda _task_id: (False, None))
    monkeypatch.setattr(module.DocumentService, "get_by_kb_id", lambda **_kwargs: ([{"id": "doc-1"}, {"id": "doc-2"}], 2))

    def _queue(**kwargs):
        queue_calls.update(kwargs)
        return "queued-id"

    monkeypatch.setattr(module, "queue_raptor_o_graphrag_tasks", _queue)
    monkeypatch.setattr(module.KnowledgebaseService, "update_by_id", lambda *_args, **_kwargs: False)
    res = inspect.unwrap(module.run_graphrag)("tenant-1", "kb-1")
    assert res["code"] == module.RetCode.SUCCESS, res
    assert res["data"]["graphrag_task_id"] == "queued-id", res
    assert queue_calls["doc_ids"] == ["doc-1", "doc-2"], queue_calls
    assert any("Cannot save graphrag_task_id" in msg for msg in warnings), warnings
    # trace_graphrag: same validation ladder, then task payload passthrough.
    res = inspect.unwrap(module.trace_graphrag)("tenant-1", "")
    assert 'Dataset ID' in res["message"], res
    monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda *_args, **_kwargs: False)
    res = inspect.unwrap(module.trace_graphrag)("tenant-1", "kb-1")
    assert res["code"] == module.RetCode.AUTHENTICATION_ERROR, res
    monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda *_args, **_kwargs: True)
    monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (False, None))
    res = inspect.unwrap(module.trace_graphrag)("tenant-1", "kb-1")
    assert "Invalid Dataset ID" in res["message"], res
    # Task id recorded but task missing -> empty trace payload.
    monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, _KB(kb_id="kb-1", graphrag_task_id="task-1")))
    monkeypatch.setattr(module.TaskService, "get_by_id", lambda _task_id: (False, None))
    res = inspect.unwrap(module.trace_graphrag)("tenant-1", "kb-1")
    assert res["code"] == module.RetCode.SUCCESS, res
    assert res["data"] == {}, res
    # Existing task -> its dict form is returned.
    monkeypatch.setattr(module.TaskService, "get_by_id", lambda _task_id: (True, SimpleNamespace(to_dict=lambda: {"id": _task_id, "progress": 1})))
    res = inspect.unwrap(module.trace_graphrag)("tenant-1", "kb-1")
    assert res["code"] == module.RetCode.SUCCESS, res
    assert res["data"]["id"] == "task-1", res
@pytest.mark.p2
def test_run_trace_raptor_matrix_unit(monkeypatch):
    """Branch matrix for run_raptor / trace_raptor on a stubbed dataset module.

    Covers, in order: missing dataset id, authorization failure, unknown
    dataset, stale-task requeue, already-running guard, queue-with-save-failure,
    and the equivalent trace_raptor lookup branches.
    """
    module = _load_dataset_module(monkeypatch)
    # Capture logging.warning calls so warning-path branches can be asserted.
    warnings = []
    monkeypatch.setattr(module.logging, "warning", lambda msg, *_args, **_kwargs: warnings.append(msg))
    # run_raptor: empty dataset id -> argument error message.
    res = inspect.unwrap(module.run_raptor)("tenant-1", "")
    assert 'Dataset ID' in res["message"], res
    # run_raptor: tenant not authorized for the dataset.
    monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda *_args, **_kwargs: False)
    res = inspect.unwrap(module.run_raptor)("tenant-1", "kb-1")
    assert res["code"] == module.RetCode.AUTHENTICATION_ERROR, res
    # run_raptor: dataset lookup fails.
    monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda *_args, **_kwargs: True)
    monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (False, None))
    res = inspect.unwrap(module.run_raptor)("tenant-1", "kb-1")
    assert "Invalid Dataset ID" in res["message"], res
    # run_raptor: a stale raptor_task_id whose task no longer exists is
    # replaced by a freshly queued task and logs a RAPTOR warning.
    stale_kb = _KB(kb_id="kb-1", raptor_task_id="task-old")
    monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, stale_kb))
    monkeypatch.setattr(module.TaskService, "get_by_id", lambda _task_id: (False, None))
    monkeypatch.setattr(module.DocumentService, "get_by_kb_id", lambda **_kwargs: ([{"id": "doc-1"}], 1))
    monkeypatch.setattr(module, "queue_raptor_o_graphrag_tasks", lambda **_kwargs: "task-new")
    monkeypatch.setattr(module.KnowledgebaseService, "update_by_id", lambda *_args, **_kwargs: True)
    res = inspect.unwrap(module.run_raptor)("tenant-1", "kb-1")
    assert res["code"] == module.RetCode.SUCCESS, res
    assert any("RAPTOR" in msg for msg in warnings), warnings
    # run_raptor: an existing task with progress 0 means the job is in flight.
    monkeypatch.setattr(module.TaskService, "get_by_id", lambda _task_id: (True, SimpleNamespace(progress=0)))
    res = inspect.unwrap(module.run_raptor)("tenant-1", "kb-1")
    assert "already running" in res["message"], res
    warnings.clear()
    # run_raptor: queueing succeeds but persisting the new task id fails;
    # the call still succeeds and only a warning is logged.
    no_task_kb = _KB(kb_id="kb-1", raptor_task_id="")
    monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, no_task_kb))
    monkeypatch.setattr(module.DocumentService, "get_by_kb_id", lambda **_kwargs: ([{"id": "doc-1"}], 1))
    monkeypatch.setattr(module, "queue_raptor_o_graphrag_tasks", lambda **_kwargs: "queued-raptor")
    monkeypatch.setattr(module.KnowledgebaseService, "update_by_id", lambda *_args, **_kwargs: False)
    res = inspect.unwrap(module.run_raptor)("tenant-1", "kb-1")
    assert res["code"] == module.RetCode.SUCCESS, res
    assert res["data"]["raptor_task_id"] == "queued-raptor", res
    assert any("Cannot save raptor_task_id" in msg for msg in warnings), warnings
    # trace_raptor: same validation ladder as run_raptor.
    res = inspect.unwrap(module.trace_raptor)("tenant-1", "")
    assert 'Dataset ID' in res["message"], res
    monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda *_args, **_kwargs: False)
    res = inspect.unwrap(module.trace_raptor)("tenant-1", "kb-1")
    assert res["code"] == module.RetCode.AUTHENTICATION_ERROR, res
    monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda *_args, **_kwargs: True)
    monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (False, None))
    res = inspect.unwrap(module.trace_raptor)("tenant-1", "kb-1")
    assert "Invalid Dataset ID" in res["message"], res
    # trace_raptor: dataset has a task id but the task row is gone.
    monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, _KB(kb_id="kb-1", raptor_task_id="task-1")))
    monkeypatch.setattr(module.TaskService, "get_by_id", lambda _task_id: (False, None))
    res = inspect.unwrap(module.trace_raptor)("tenant-1", "kb-1")
    assert "RAPTOR Task Not Found" in res["message"], res
    # trace_raptor: happy path returns the task dict.
    monkeypatch.setattr(module.TaskService, "get_by_id", lambda _task_id: (True, SimpleNamespace(to_dict=lambda: {"id": _task_id, "progress": -1})))
    res = inspect.unwrap(module.trace_raptor)("tenant-1", "kb-1")
    assert res["code"] == module.RetCode.SUCCESS, res
    assert res["data"]["id"] == "task-1", res

View File

@ -0,0 +1,549 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import importlib.util
import inspect
import sys
from pathlib import Path
from types import ModuleType, SimpleNamespace
from functools import wraps
import pytest
class _DummyManager:
def route(self, *_args, **_kwargs):
def decorator(func):
return func
return decorator
class _AwaitableValue:
def __init__(self, value):
self._value = value
def __await__(self):
async def _co():
return self._value
return _co().__await__()
class _Args(dict):
def get(self, key, default=None):
return super().get(key, default)
def _run(coro):
return asyncio.run(coro)
def _set_request_json(monkeypatch, module, payload):
    """Patch *module*'s ``get_request_json`` so awaiting it yields *payload*."""
    def _fake_get_request_json():
        return _AwaitableValue(payload)

    monkeypatch.setattr(module, "get_request_json", _fake_get_request_json)
def _set_request_args(monkeypatch, module, args):
    """Patch *module*'s ``request`` with a stub exposing *args* via ``.args``."""
    stub_request = SimpleNamespace(args=_Args(args))
    monkeypatch.setattr(module, "request", stub_request)
@pytest.fixture(scope="session")
def auth():
    # Session-wide placeholder auth token; these unit tests never contact a
    # real server, the value only satisfies fixtures that require one.
    return "unit-auth"
@pytest.fixture(scope="session", autouse=True)
def set_tenant_info():
    # Autouse no-op: overrides any shared-suite tenant-setup fixture so the
    # stub-module tests run without external state.
    return None
def _load_dialog_module(monkeypatch):
    """Load ``api/apps/dialog_app.py`` in isolation with stubbed dependencies.

    Installs fake ``quart``, ``api.*`` packages and service classes into
    ``sys.modules`` (via monkeypatch, so the real modules are restored after
    each test), then executes dialog_app.py from source and returns the
    loaded module. Ordering matters: every stub must be registered before
    ``spec.loader.exec_module`` runs the module's imports.
    """
    repo_root = Path(__file__).resolve().parents[4]
    # Real "common" package from the repo (needed for common.constants.RetCode).
    common_pkg = ModuleType("common")
    common_pkg.__path__ = [str(repo_root / "common")]
    monkeypatch.setitem(sys.modules, "common", common_pkg)
    # Fake quart with a patchable module-level ``request``.
    quart_mod = ModuleType("quart")
    quart_mod.request = SimpleNamespace(args=_Args())
    monkeypatch.setitem(sys.modules, "quart", quart_mod)
    api_pkg = ModuleType("api")
    api_pkg.__path__ = [str(repo_root / "api")]
    monkeypatch.setitem(sys.modules, "api", api_pkg)
    # api.apps provides current_user and a pass-through login_required.
    apps_mod = ModuleType("api.apps")
    apps_mod.__path__ = [str(repo_root / "api" / "apps")]
    apps_mod.current_user = SimpleNamespace(id="tenant-1")
    apps_mod.login_required = lambda func: func
    monkeypatch.setitem(sys.modules, "api.apps", apps_mod)
    api_pkg.apps = apps_mod
    db_pkg = ModuleType("api.db")
    db_pkg.__path__ = []
    monkeypatch.setitem(sys.modules, "api.db", db_pkg)
    api_pkg.db = db_pkg
    services_pkg = ModuleType("api.db.services")
    services_pkg.__path__ = []
    # duplicate_name default: echo the requested name unchanged.
    services_pkg.duplicate_name = lambda _checker, **kwargs: kwargs.get("name", "")
    monkeypatch.setitem(sys.modules, "api.db.services", services_pkg)
    dialog_service_mod = ModuleType("api.db.services.dialog_service")

    # Default DialogService stub: benign success values, overridden per test.
    class _DialogService:
        model = SimpleNamespace(create_time="create_time")

        @staticmethod
        def query(**_kwargs):
            return []

        @staticmethod
        def save(**_kwargs):
            return True

        @staticmethod
        def update_by_id(*_args, **_kwargs):
            return True

        @staticmethod
        def get_by_id(_id):
            return True, SimpleNamespace(to_dict=lambda: {"id": _id, "kb_ids": []})

        @staticmethod
        def get_by_tenant_ids(*_args, **_kwargs):
            return [], 0

        @staticmethod
        def update_many_by_id(_payload):
            return True

    dialog_service_mod.DialogService = _DialogService
    monkeypatch.setitem(sys.modules, "api.db.services.dialog_service", dialog_service_mod)
    tenant_llm_service_mod = ModuleType("api.db.services.tenant_llm_service")

    class _TenantLLMService:
        @staticmethod
        def split_model_name_and_factory(embd_id):
            # "model@factory" -> [model, factory]
            return embd_id.split("@")

    tenant_llm_service_mod.TenantLLMService = _TenantLLMService
    monkeypatch.setitem(sys.modules, "api.db.services.tenant_llm_service", tenant_llm_service_mod)
    knowledgebase_service_mod = ModuleType("api.db.services.knowledgebase_service")

    class _KnowledgebaseService:
        @staticmethod
        def get_by_ids(_ids):
            return []

        @staticmethod
        def get_by_id(_id):
            return False, None

        @staticmethod
        def query(**_kwargs):
            return []

    knowledgebase_service_mod.KnowledgebaseService = _KnowledgebaseService
    monkeypatch.setitem(sys.modules, "api.db.services.knowledgebase_service", knowledgebase_service_mod)
    user_service_mod = ModuleType("api.db.services.user_service")

    class _TenantService:
        @staticmethod
        def get_by_id(_id):
            return True, SimpleNamespace(llm_id="llm-default")

    class _UserTenantService:
        @staticmethod
        def query(**_kwargs):
            return [SimpleNamespace(tenant_id="tenant-1")]

    user_service_mod.TenantService = _TenantService
    user_service_mod.UserTenantService = _UserTenantService
    monkeypatch.setitem(sys.modules, "api.db.services.user_service", user_service_mod)
    # api.utils.api_utils: dict-returning result helpers so tests can assert
    # on plain dicts instead of HTTP responses.
    api_utils_mod = ModuleType("api.utils.api_utils")
    from common.constants import RetCode

    async def _default_request_json():
        return {}

    def _get_data_error_result(code=RetCode.DATA_ERROR, message="Sorry! Data missing!"):
        return {"code": code, "message": message}

    def _get_json_result(code=RetCode.SUCCESS, message="success", data=None):
        return {"code": code, "message": message, "data": data}

    def _server_error_response(error):
        return {"code": RetCode.EXCEPTION_ERROR, "message": repr(error)}

    def _validate_request(*_args, **_kwargs):
        # Pass-through decorator preserving sync/async nature of the handler.
        def _decorator(func):
            if inspect.iscoroutinefunction(func):
                @wraps(func)
                async def _wrapped(*func_args, **func_kwargs):
                    return await func(*func_args, **func_kwargs)
                return _wrapped

            @wraps(func)
            def _wrapped(*func_args, **func_kwargs):
                return func(*func_args, **func_kwargs)
            return _wrapped
        return _decorator

    api_utils_mod.get_request_json = _default_request_json
    api_utils_mod.get_data_error_result = _get_data_error_result
    api_utils_mod.get_json_result = _get_json_result
    api_utils_mod.server_error_response = _server_error_response
    api_utils_mod.validate_request = _validate_request
    monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod)
    # Execute dialog_app.py from source against the stubbed environment.
    module_name = "test_dialog_routes_unit_module"
    module_path = repo_root / "api" / "apps" / "dialog_app.py"
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    module = importlib.util.module_from_spec(spec)
    # ``manager`` must exist before exec so @manager.route decorators resolve.
    module.manager = _DummyManager()
    monkeypatch.setitem(sys.modules, module_name, module)
    spec.loader.exec_module(module)
    return module
@pytest.mark.p2
def test_set_dialog_branch_matrix_unit(monkeypatch):
    """Branch matrix for the set_dialog handler.

    Walks the validation ladder (name type/emptiness/length), the create path
    with name deduplication and {knowledge} parameter injection, prompt-config
    validation errors, tenant/embedding checks, the update path failures and
    success, and finally the exception branch.
    """
    module = _load_dialog_module(monkeypatch)
    handler = inspect.unwrap(module.set_dialog)
    # Name must be a string.
    _set_request_json(monkeypatch, module, {"name": 1, "prompt_config": {"system": "", "parameters": []}})
    res = _run(handler())
    assert res["message"] == "Dialog name must be string."
    # Whitespace-only name rejected.
    _set_request_json(monkeypatch, module, {"name": " ", "prompt_config": {"system": "", "parameters": []}})
    res = _run(handler())
    assert res["message"] == "Dialog name can't be empty."
    # Over-long name rejected.
    _set_request_json(monkeypatch, module, {"name": "a" * 256, "prompt_config": {"system": "", "parameters": []}})
    res = _run(handler())
    assert res["message"] == "Dialog name length is 256 which is larger than 255"
    captured = {}

    def _dup_name(checker, **kwargs):
        # The handler passes a query callable; exercise it before renaming.
        assert checker(name=kwargs["name"]) is True
        return kwargs["name"] + " (1)"

    # Create path: duplicate name gets a " (1)" suffix, {knowledge} in the
    # system prompt auto-registers the "knowledge" parameter, and a failing
    # save() surfaces as "Fail to new a dialog!".
    monkeypatch.setattr(module, "duplicate_name", _dup_name)
    monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: [SimpleNamespace(name="new dialog")])
    monkeypatch.setattr(module.TenantService, "get_by_id", lambda _id: (True, SimpleNamespace(llm_id="llm-x")))
    monkeypatch.setattr(module.KnowledgebaseService, "get_by_ids", lambda _ids: [SimpleNamespace(embd_id="embd-a@builtin")])
    monkeypatch.setattr(module.TenantLLMService, "split_model_name_and_factory", lambda embd_id: embd_id.split("@"))
    monkeypatch.setattr(module.DialogService, "save", lambda **kwargs: captured.update(kwargs) or False)
    _set_request_json(
        monkeypatch,
        module,
        {
            "name": "New Dialog",
            "kb_ids": ["kb-1"],
            "prompt_config": {"system": "Use {knowledge}", "parameters": []},
        },
    )
    res = _run(handler())
    assert res["message"] == "Fail to new a dialog!"
    assert captured["name"] == "New Dialog (1)"
    assert captured["prompt_config"]["parameters"] == [{"key": "knowledge", "optional": False}]
    # {knowledge} in the prompt with no datasets attached must be rejected.
    _set_request_json(
        monkeypatch,
        module,
        {
            "dialog_id": "dialog-1",
            "name": "Update",
            "kb_ids": [],
            "prompt_config": {
                "system": "Use {knowledge}",
                "parameters": [{"key": "knowledge", "optional": True}],
            },
        },
    )
    res = _run(handler())
    assert "Please remove `{knowledge}` in system prompt" in res["message"]
    # A required parameter not referenced by the system prompt is an error.
    _set_request_json(
        monkeypatch,
        module,
        {"name": "demo", "prompt_config": {"system": "hello", "parameters": [{"key": "must", "optional": False}]}},
    )
    res = _run(handler())
    assert "Parameter 'must' is not used" in res["message"]
    # Missing tenant row.
    monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: [])
    monkeypatch.setattr(module.TenantService, "get_by_id", lambda _id: (False, None))
    _set_request_json(monkeypatch, module, {"name": "demo", "prompt_config": {"system": "hello", "parameters": []}})
    res = _run(handler())
    assert res["message"] == "Tenant not found!"
    # Mixed embedding models across the selected datasets are rejected.
    monkeypatch.setattr(module.TenantService, "get_by_id", lambda _id: (True, SimpleNamespace(llm_id="llm-x")))
    monkeypatch.setattr(
        module,
        "get_request_json",
        lambda: _AwaitableValue(
            {
                "name": "demo",
                "kb_ids": ["kb-1", "kb-2"],
                "prompt_config": {"system": "hello", "parameters": []},
            }
        ),
    )
    monkeypatch.setattr(
        module.KnowledgebaseService,
        "get_by_ids",
        lambda _ids: [SimpleNamespace(embd_id="embd-a@f1"), SimpleNamespace(embd_id="embd-b@f2")],
    )
    monkeypatch.setattr(module.TenantLLMService, "split_model_name_and_factory", lambda embd_id: embd_id.split("@"))
    res = _run(handler())
    assert "Datasets use different embedding models" in res["message"]
    # Optional unused parameters are tolerated; save() failing still errors.
    monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: [])
    monkeypatch.setattr(
        module,
        "get_request_json",
        lambda: _AwaitableValue(
            {
                "name": "optional-param-dialog",
                "prompt_config": {"system": "hello", "parameters": [{"key": "ignored", "optional": True}]},
            }
        ),
    )
    monkeypatch.setattr(module.KnowledgebaseService, "get_by_ids", lambda _ids: [])
    monkeypatch.setattr(module.DialogService, "save", lambda **_kwargs: False)
    res = _run(handler())
    assert res["message"] == "Fail to new a dialog!"
    # Update path: update_by_id returning False means the dialog is missing.
    monkeypatch.setattr(module.KnowledgebaseService, "get_by_ids", lambda _ids: [])
    monkeypatch.setattr(module.DialogService, "update_by_id", lambda *_args, **_kwargs: False)
    _set_request_json(
        monkeypatch,
        module,
        {
            "dialog_id": "dialog-1",
            "kb_names": ["legacy"],
            "name": "rename",
            "prompt_config": {"system": "hello", "parameters": []},
        },
    )
    res = _run(handler())
    assert res["message"] == "Dialog not found!"
    # Update succeeds but the refetch fails.
    monkeypatch.setattr(module.DialogService, "update_by_id", lambda *_args, **_kwargs: True)
    monkeypatch.setattr(module.DialogService, "get_by_id", lambda _id: (False, None))
    _set_request_json(
        monkeypatch,
        module,
        {
            "dialog_id": "dialog-1",
            "name": "rename",
            "prompt_config": {"system": "hello", "parameters": []},
        },
    )
    res = _run(handler())
    assert res["message"] == "Fail to update a dialog!"
    # Update happy path: kb_names resolved from the dialog's valid datasets.
    monkeypatch.setattr(module.DialogService, "get_by_id", lambda _id: (True, SimpleNamespace(to_dict=lambda: {"id": _id, "kb_ids": ["kb-1"]})))
    monkeypatch.setattr(
        module.KnowledgebaseService,
        "get_by_id",
        lambda _id: (True, SimpleNamespace(status=module.StatusEnum.VALID.value, name="KB One")),
    )
    _set_request_json(
        monkeypatch,
        module,
        {
            "dialog_id": "dialog-1",
            "kb_names": ["legacy"],
            "name": "new-name",
            "prompt_config": {"system": "hello", "parameters": []},
        },
    )
    res = _run(handler())
    assert res["code"] == 0
    assert res["data"]["name"] == "new-name"
    assert res["data"]["kb_names"] == ["KB One"]

    def _raise_tenant(_id):
        raise RuntimeError("set boom")

    # Unexpected exception funnels through server_error_response.
    monkeypatch.setattr(module.TenantService, "get_by_id", _raise_tenant)
    _set_request_json(monkeypatch, module, {"name": "demo", "prompt_config": {"system": "hello", "parameters": []}})
    res = _run(handler())
    assert "set boom" in res["message"]
@pytest.mark.p2
def test_get_get_kb_names_and_list_dialogs_exception_matrix_unit(monkeypatch):
    """Covers the get handler (valid-KB filtering, not-found, exception),
    the get_kb_names helper's status filtering, and the list_dialogs
    exception branch."""
    module = _load_dialog_module(monkeypatch)
    get_handler = inspect.unwrap(module.get)
    # get: dialogs referencing a missing KB drop it from kb_ids/kb_names.
    monkeypatch.setattr(
        module.DialogService,
        "get_by_id",
        lambda _id: (True, SimpleNamespace(to_dict=lambda: {"id": _id, "kb_ids": ["kb-1", "kb-2"]})),
    )
    monkeypatch.setattr(
        module.KnowledgebaseService,
        "get_by_id",
        lambda kid: (
            (True, SimpleNamespace(status=module.StatusEnum.VALID.value, name="KB-1"))
            if kid == "kb-1"
            else (False, None)
        ),
    )
    _set_request_args(monkeypatch, module, {"dialog_id": "dialog-1"})
    res = get_handler()
    assert res["code"] == 0
    assert res["data"]["kb_ids"] == ["kb-1"]
    assert res["data"]["kb_names"] == ["KB-1"]
    # get: unknown dialog id.
    monkeypatch.setattr(module.DialogService, "get_by_id", lambda _id: (False, None))
    _set_request_args(monkeypatch, module, {"dialog_id": "dialog-missing"})
    res = get_handler()
    assert res["message"] == "Dialog not found!"

    def _raise_get(_id):
        raise RuntimeError("get boom")

    # get: unexpected exception is surfaced via server_error_response.
    monkeypatch.setattr(module.DialogService, "get_by_id", _raise_get)
    _set_request_args(monkeypatch, module, {"dialog_id": "dialog-1"})
    res = get_handler()
    assert "get boom" in res["message"]
    # get_kb_names: only VALID-status datasets are kept, order preserved.
    monkeypatch.setattr(
        module.KnowledgebaseService,
        "get_by_id",
        lambda kid: (
            (True, SimpleNamespace(status=module.StatusEnum.VALID.value, name=f"KB-{kid}"))
            if kid.startswith("ok")
            else (True, SimpleNamespace(status=module.StatusEnum.INVALID.value, name=f"BAD-{kid}"))
        ),
    )
    ids, names = module.get_kb_names(["ok-1", "bad-1", "ok-2"])
    assert ids == ["ok-1", "ok-2"]
    assert names == ["KB-ok-1", "KB-ok-2"]

    def _raise_list(**_kwargs):
        raise RuntimeError("list boom")

    # list_dialogs: query exception goes through server_error_response.
    monkeypatch.setattr(module.DialogService, "query", _raise_list)
    res = module.list_dialogs()
    assert "list boom" in res["message"]
@pytest.mark.p2
def test_list_dialogs_next_owner_desc_and_pagination_matrix_unit(monkeypatch):
    """Covers list_dialogs_next: query-arg forwarding (desc=false), in-memory
    owner filtering + pagination when owner_ids are given, and the exception
    branch."""
    module = _load_dialog_module(monkeypatch)
    handler = inspect.unwrap(module.list_dialogs_next)
    # Record every call so the forwarded arguments can be asserted.
    calls = []

    def _get_by_tenant_ids(tenants, user_id, page_number, items_per_page, orderby, desc, keywords, parser_id):
        calls.append(
            {
                "tenants": tenants,
                "user_id": user_id,
                "page_number": page_number,
                "items_per_page": items_per_page,
                "orderby": orderby,
                "desc": desc,
                "keywords": keywords,
                "parser_id": parser_id,
            }
        )
        if tenants:
            # tenant-x is owned by no requested owner and must be filtered out.
            return (
                [
                    {"id": "dialog-1", "tenant_id": "tenant-a"},
                    {"id": "dialog-2", "tenant_id": "tenant-x"},
                    {"id": "dialog-3", "tenant_id": "tenant-b"},
                ],
                3,
            )
        return ([{"id": "dialog-0", "tenant_id": "tenant-1"}], 1)

    monkeypatch.setattr(module.DialogService, "get_by_tenant_ids", _get_by_tenant_ids)
    # No owner_ids: args forwarded verbatim, "false" parsed to desc=False.
    _set_request_args(
        monkeypatch,
        module,
        {
            "keywords": "k",
            "page": "1",
            "page_size": "2",
            "parser_id": "parser-x",
            "orderby": "create_time",
            "desc": "false",
        },
    )
    _set_request_json(monkeypatch, module, {"owner_ids": []})
    res = _run(handler())
    assert res["code"] == 0
    assert res["data"]["total"] == 1
    assert calls[-1]["tenants"] == []
    assert calls[-1]["desc"] is False
    # With owner_ids: handler fetches everything (page 0 / size 0), filters to
    # the requested owners, then paginates in memory (page 2 of size 1).
    _set_request_args(monkeypatch, module, {"page": "2", "page_size": "1"})
    _set_request_json(monkeypatch, module, {"owner_ids": ["tenant-a", "tenant-b"]})
    res = _run(handler())
    assert res["code"] == 0
    assert res["data"]["total"] == 2
    assert res["data"]["dialogs"] == [{"id": "dialog-3", "tenant_id": "tenant-b"}]
    assert calls[-1]["page_number"] == 0
    assert calls[-1]["items_per_page"] == 0
    assert calls[-1]["desc"] is True

    def _raise_next(*_args, **_kwargs):
        raise RuntimeError("next boom")

    # Exception branch.
    monkeypatch.setattr(module.DialogService, "get_by_tenant_ids", _raise_next)
    _set_request_args(monkeypatch, module, {"page": "1", "page_size": "1"})
    _set_request_json(monkeypatch, module, {"owner_ids": []})
    res = _run(handler())
    assert "next boom" in res["message"]
@pytest.mark.p2
def test_rm_permission_and_exception_matrix_unit(monkeypatch):
    """Covers rm: non-owner deletion attempt is rejected with OPERATING_ERROR,
    and a query exception funnels through server_error_response."""
    module = _load_dialog_module(monkeypatch)
    handler = inspect.unwrap(module.rm)
    # Dialog is not owned by the caller and no owned match exists.
    monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [SimpleNamespace(tenant_id="tenant-a")])
    monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: [])
    _set_request_json(monkeypatch, module, {"dialog_ids": ["dialog-1"]})
    res = _run(handler())
    assert res["code"] == module.RetCode.OPERATING_ERROR
    assert "Only owner of dialog authorized for this operation." in res["message"]

    def _raise_query(**_kwargs):
        raise RuntimeError("rm boom")

    # Exception branch.
    monkeypatch.setattr(module.DialogService, "query", _raise_query)
    _set_request_json(monkeypatch, module, {"dialog_ids": ["dialog-1"]})
    res = _run(handler())
    assert "rm boom" in res["message"]

View File

@ -250,6 +250,31 @@ def _run(coro):
return asyncio.run(coro)
class _DummyArgs:
def __init__(self, args=None):
self._args = args or {}
def get(self, key, default=None):
return self._args.get(key, default)
def getlist(self, key):
value = self._args.get(key, [])
if isinstance(value, list):
return value
return [value]
class _DummyRequest:
    """Request stub exposing only ``.args``, wrapped in a _DummyArgs mapping."""

    def __init__(self, args=None):
        wrapped = _DummyArgs(args)
        self.args = wrapped
class _DummyResponse:
def __init__(self, data=None):
self.data = data
self.headers = {}
@pytest.mark.p2
class TestDocumentMetadataUnit:
def _allow_kb(self, module, monkeypatch, kb_id="kb1", tenant_id="tenant1"):
@ -411,3 +436,546 @@ class TestDocumentMetadataUnit:
res = _run(module.metadata_update.__wrapped__())
assert res["code"] == 0
assert res["data"]["matched_docs"] == 1
def test_metadata_update_invalid_delete_item_unit(self, document_app_module, monkeypatch):
    """A delete entry without a "key" field is rejected as an argument error."""
    module = document_app_module

    async def fake_request_json():
        return {"kb_id": "kb1", "doc_ids": ["doc1"], "updates": [], "deletes": [{}]}

    monkeypatch.setattr(module, "get_request_json", fake_request_json)
    res = _run(module.metadata_update.__wrapped__())
    assert res["code"] == module.RetCode.ARGUMENT_ERROR
    assert "Each delete requires key." in res["message"]
def test_update_metadata_setting_authorization_and_refetch_not_found_unit(self, document_app_module, monkeypatch):
    """update_metadata_setting: unauthorized caller gets AUTHENTICATION_ERROR;
    a document that disappears between fetch and refetch gets DATA_ERROR."""
    module = document_app_module

    async def fake_request_json():
        return {"doc_id": "doc1", "metadata": {"author": "alice"}}

    # Unauthorized branch.
    monkeypatch.setattr(module, "get_request_json", fake_request_json)
    monkeypatch.setattr(module.DocumentService, "accessible", lambda *_args, **_kwargs: False)
    res = _run(module.update_metadata_setting.__wrapped__())
    assert res["code"] == module.RetCode.AUTHENTICATION_ERROR
    assert "No authorization." in res["message"]
    doc = SimpleNamespace(id="doc1", to_dict=lambda: {"id": "doc1", "parser_config": {}})
    # First get_by_id succeeds, the refetch after the update does not.
    state = {"count": 0}

    def fake_get_by_id(_doc_id):
        state["count"] += 1
        if state["count"] == 1:
            return True, doc
        return False, None

    monkeypatch.setattr(module.DocumentService, "accessible", lambda *_args, **_kwargs: True)
    monkeypatch.setattr(module.DocumentService, "get_by_id", fake_get_by_id)
    monkeypatch.setattr(module.DocumentService, "update_parser_config", lambda *_args, **_kwargs: True)
    res = _run(module.update_metadata_setting.__wrapped__())
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "Document not found!" in res["message"]
def test_thumbnails_missing_ids_rewrite_and_exception_unit(self, document_app_module, monkeypatch):
    """thumbnails: missing doc_ids -> argument error; file thumbnails are
    rewritten to image URLs while base64 thumbnails pass through; service
    exceptions funnel through server_error_response."""
    module = document_app_module
    # No doc_ids supplied.
    monkeypatch.setattr(module, "request", _DummyRequest(args={}))
    res = module.thumbnails()
    assert res["code"] == module.RetCode.ARGUMENT_ERROR
    assert 'Lack of "Document ID"' in res["message"]
    # Mixed thumbnails: stored-file name vs. inline base64 data.
    monkeypatch.setattr(module, "request", _DummyRequest(args={"doc_ids": ["doc1", "doc2"]}))
    monkeypatch.setattr(
        module.DocumentService,
        "get_thumbnails",
        lambda _doc_ids: [
            {"id": "doc1", "kb_id": "kb1", "thumbnail": "thumb.jpg"},
            {"id": "doc2", "kb_id": "kb1", "thumbnail": f"{module.IMG_BASE64_PREFIX}blob"},
        ],
    )
    res = module.thumbnails()
    assert res["code"] == 0
    assert res["data"]["doc1"] == "/v1/document/image/kb1-thumb.jpg"
    assert res["data"]["doc2"] == f"{module.IMG_BASE64_PREFIX}blob"

    def raise_error(*_args, **_kwargs):
        raise RuntimeError("thumb boom")

    # Exception branch.
    monkeypatch.setattr(module.DocumentService, "get_thumbnails", raise_error)
    monkeypatch.setattr(module, "server_error_response", lambda e: {"code": 500, "message": str(e)})
    res = module.thumbnails()
    assert res["code"] == 500
    assert "thumb boom" in res["message"]
def test_change_status_partial_failure_matrix_unit(self, document_app_module, monkeypatch):
    """change_status with a batch where every document fails differently:
    unauthorized, missing doc, missing dataset, DB update failure, doc-store
    '3022' table-missing error, generic doc-store error, and an unexpected
    exception — the handler reports SERVER_ERROR with per-doc error details."""
    module = document_app_module
    calls = {"docstore_update": []}
    # One doc id per failure mode.
    doc_ids = ["unauth", "missing_doc", "missing_kb", "update_fail", "docstore_3022", "docstore_generic", "outer_exc"]

    async def fake_request_json():
        return {"doc_ids": doc_ids, "status": "1"}

    def fake_accessible(doc_id, _uid):
        return doc_id != "unauth"

    def fake_get_by_id(doc_id):
        if doc_id == "missing_doc":
            return False, None
        if doc_id == "outer_exc":
            raise RuntimeError("explode")
        kb_id = "kb_missing" if doc_id == "missing_kb" else "kb1"
        # Only chunked docs trigger the doc-store update path.
        chunk_num = 1 if doc_id in {"docstore_3022", "docstore_generic"} else 0
        doc = SimpleNamespace(id=doc_id, kb_id=kb_id, status="0", chunk_num=chunk_num)
        return True, doc

    def fake_get_kb(kb_id):
        if kb_id == "kb_missing":
            return False, None
        return True, SimpleNamespace(tenant_id="tenant1")

    def fake_update_by_id(doc_id, _payload):
        return doc_id != "update_fail"

    class _DocStore:
        def update(self, where, _payload, _index_name, _kb_id):
            calls["docstore_update"].append(where["doc_id"])
            if where["doc_id"] == "docstore_3022":
                # "3022" in the message maps to the table-missing branch.
                raise RuntimeError("3022 table missing")
            if where["doc_id"] == "docstore_generic":
                raise RuntimeError("doc store down")
            return True

    monkeypatch.setattr(module, "get_request_json", fake_request_json)
    monkeypatch.setattr(module.DocumentService, "accessible", fake_accessible)
    monkeypatch.setattr(module.DocumentService, "get_by_id", fake_get_by_id)
    monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda kb_id: fake_get_kb(kb_id))
    monkeypatch.setattr(module.DocumentService, "update_by_id", fake_update_by_id)
    monkeypatch.setattr(module.settings, "docStoreConn", _DocStore())
    monkeypatch.setattr(module.search, "index_name", lambda tenant_id: f"idx_{tenant_id}")
    res = _run(module.change_status.__wrapped__())
    assert res["code"] == module.RetCode.SERVER_ERROR
    assert res["message"] == "Partial failure"
    assert res["data"]["unauth"]["error"] == "No authorization."
    assert res["data"]["missing_doc"]["error"] == "No authorization."
    assert res["data"]["missing_kb"]["error"] == "Can't find this dataset!"
    assert res["data"]["update_fail"]["error"] == "Database error (Document update)!"
    assert res["data"]["docstore_3022"]["error"] == "Document store table missing."
    assert "Document store update failed:" in res["data"]["docstore_generic"]["error"]
    assert "Internal server error: explode" == res["data"]["outer_exc"]["error"]
    # Only the chunked documents reached the doc store.
    assert calls["docstore_update"] == ["docstore_3022", "docstore_generic"]
def test_change_status_invalid_status_unit(self, document_app_module, monkeypatch):
    """change_status rejects any status other than "0" or "1"."""
    module = document_app_module

    async def fake_request_json():
        return {"doc_ids": ["doc1"], "status": "2"}

    monkeypatch.setattr(module, "get_request_json", fake_request_json)
    res = _run(module.change_status.__wrapped__())
    assert res["code"] == module.RetCode.ARGUMENT_ERROR
    assert '"Status" must be either 0 or 1!' in res["message"]
def test_change_status_all_success_unit(self, document_app_module, monkeypatch):
    """change_status happy path: one accessible, chunkless doc flips to "1"."""
    module = document_app_module

    async def fake_request_json():
        return {"doc_ids": ["doc1"], "status": "1"}

    monkeypatch.setattr(module, "get_request_json", fake_request_json)
    monkeypatch.setattr(module.DocumentService, "accessible", lambda *_args, **_kwargs: True)
    # chunk_num=0 keeps the flow off the doc-store update path.
    monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (True, SimpleNamespace(id="doc1", kb_id="kb1", status="0", chunk_num=0)))
    monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, SimpleNamespace(tenant_id="tenant1")))
    monkeypatch.setattr(module.DocumentService, "update_by_id", lambda *_args, **_kwargs: True)
    res = _run(module.change_status.__wrapped__())
    assert res["code"] == 0
    assert res["data"]["doc1"]["status"] == "1"
def test_rename_branch_matrix_and_exception_unit(self, document_app_module, monkeypatch):
    """rename branch matrix: authorization, missing doc, extension change,
    over-long name, duplicate name, the full happy path (file record and
    doc-store title fields updated), and the DB-exception branch."""
    module = document_app_module
    file_updates = []
    es_updates = []

    async def fake_thread_pool_exec(func, *_args, **_kwargs):
        # Run pool work inline so side effects happen synchronously.
        return func()

    monkeypatch.setattr(module, "thread_pool_exec", fake_thread_pool_exec)
    monkeypatch.setattr(module.DocumentService, "get_tenant_id", lambda _doc_id: "tenant1")
    monkeypatch.setattr(module.rag_tokenizer, "tokenize", lambda _name: ["token"])
    monkeypatch.setattr(module.rag_tokenizer, "fine_grained_tokenize", lambda _tokens: ["fine"])
    monkeypatch.setattr(module, "server_error_response", lambda e: {"code": 500, "message": str(e)})

    class _DocStore:
        def index_exist(self, _index_name, _kb_id):
            return True

        def update(self, where, payload, _index_name, _kb_id):
            es_updates.append((where, payload))

    monkeypatch.setattr(module.settings, "docStoreConn", _DocStore())
    monkeypatch.setattr(module.search, "index_name", lambda tenant_id: f"idx_{tenant_id}")

    def set_req(name):
        # Helper: point get_request_json at a rename payload for doc1.
        async def fake_request_json():
            return {"doc_id": "doc1", "name": name}

        monkeypatch.setattr(module, "get_request_json", fake_request_json)

    # Unauthorized caller.
    set_req("renamed.txt")
    monkeypatch.setattr(module.DocumentService, "accessible", lambda *_args, **_kwargs: False)
    res = _run(module.rename.__wrapped__())
    assert res["code"] == module.RetCode.AUTHENTICATION_ERROR
    # Unknown document.
    monkeypatch.setattr(module.DocumentService, "accessible", lambda *_args, **_kwargs: True)
    monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (False, None))
    res = _run(module.rename.__wrapped__())
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "Document not found!" in res["message"]
    # Extension must not change (.txt -> .pdf rejected).
    monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (True, SimpleNamespace(id="doc1", name="origin.txt", kb_id="kb1")))
    set_req("renamed.pdf")
    res = _run(module.rename.__wrapped__())
    assert res["code"] == module.RetCode.ARGUMENT_ERROR
    assert "extension" in res["message"]
    # Name exceeding FILE_NAME_LEN_LIMIT rejected.
    too_long = "a" * (module.FILE_NAME_LEN_LIMIT + 1) + ".txt"
    set_req(too_long)
    res = _run(module.rename.__wrapped__())
    assert res["code"] == module.RetCode.ARGUMENT_ERROR
    assert "bytes or less" in res["message"]
    # Duplicate name within the dataset rejected.
    set_req("dup.txt")
    monkeypatch.setattr(module.DocumentService, "query", lambda **_kwargs: [SimpleNamespace(name="dup.txt")])
    res = _run(module.rename.__wrapped__())
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "Duplicated document name" in res["message"]
    # Happy path: linked file record renamed and doc-store title fields updated.
    set_req("ok.txt")
    monkeypatch.setattr(module.DocumentService, "query", lambda **_kwargs: [])
    monkeypatch.setattr(module.DocumentService, "update_by_id", lambda *_args, **_kwargs: True)
    monkeypatch.setattr(module.File2DocumentService, "get_by_document_id", lambda _doc_id: [SimpleNamespace(file_id="file1")])
    monkeypatch.setattr(module.FileService, "get_by_id", lambda _file_id: (True, SimpleNamespace(id="file1")))
    monkeypatch.setattr(module.FileService, "update_by_id", lambda file_id, payload: file_updates.append((file_id, payload)))
    res = _run(module.rename.__wrapped__())
    assert res["code"] == 0
    assert file_updates == [("file1", {"name": "ok.txt"})]
    assert es_updates[0][0] == {"doc_id": "doc1"}
    assert es_updates[0][1]["docnm_kwd"] == "ok.txt"
    assert es_updates[0][1]["title_tks"] == ["token"]
    assert es_updates[0][1]["title_sm_tks"] == ["fine"]

    def raise_db_error(*_args, **_kwargs):
        raise RuntimeError("rename boom")

    # Exception branch.
    monkeypatch.setattr(module.DocumentService, "update_by_id", raise_db_error)
    res = _run(module.rename.__wrapped__())
    assert res["code"] == 500
    assert "rename boom" in res["message"]
def test_get_route_not_found_success_and_exception_unit(self, document_app_module, monkeypatch):
    """get route: missing document -> DATA_ERROR; visual file streams a blob
    with content-type derived from the extension; exceptions go through
    server_error_response."""
    module = document_app_module
    # Unknown document.
    monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (False, None))
    res = _run(module.get("doc1"))
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "Document not found!" in res["message"]

    async def fake_thread_pool_exec(*_args, **_kwargs):
        return b"blob-data"

    async def fake_make_response(data):
        return _DummyResponse(data)

    # Happy path for a VISUAL-type file named "image.abc" -> "image/abc".
    monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (True, SimpleNamespace(name="image.abc", type=module.FileType.VISUAL.value)))
    monkeypatch.setattr(module.File2DocumentService, "get_storage_address", lambda **_kwargs: ("bucket", "name"))
    monkeypatch.setattr(module.settings, "STORAGE_IMPL", SimpleNamespace(get=lambda *_args, **_kwargs: b"blob-data"))
    monkeypatch.setattr(module, "thread_pool_exec", fake_thread_pool_exec)
    monkeypatch.setattr(module, "make_response", fake_make_response)
    monkeypatch.setattr(
        module,
        "apply_safe_file_response_headers",
        lambda response, content_type, extension: response.headers.update({"content_type": content_type, "extension": extension}),
    )
    res = _run(module.get("doc1"))
    assert isinstance(res, _DummyResponse)
    assert res.data == b"blob-data"
    assert res.headers["content_type"] == "image/abc"
    assert res.headers["extension"] == "abc"
    # Exception branch (generator-throw lambda raises inside get_by_id).
    monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (_ for _ in ()).throw(RuntimeError("get boom")))
    monkeypatch.setattr(module, "server_error_response", lambda e: {"code": 500, "message": str(e)})
    res = _run(module.get("doc1"))
    assert res["code"] == 500
    assert "get boom" in res["message"]
def test_download_attachment_success_and_exception_unit(self, document_app_module, monkeypatch):
    """Exercise `download_attachment`: success path with safe headers, then storage failure."""
    module = document_app_module
    # Request carries the file extension via query args; "abc" -> "application/abc".
    monkeypatch.setattr(module, "request", _DummyRequest(args={"ext": "abc"}))
    async def fake_thread_pool_exec(*_args, **_kwargs):
        return b"attachment"
    async def fake_make_response(data):
        return _DummyResponse(data)
    monkeypatch.setattr(module, "thread_pool_exec", fake_thread_pool_exec)
    monkeypatch.setattr(module, "make_response", fake_make_response)
    monkeypatch.setattr(module.settings, "STORAGE_IMPL", SimpleNamespace(get=lambda *_args, **_kwargs: b"attachment"))
    monkeypatch.setattr(
        module,
        "apply_safe_file_response_headers",
        lambda response, content_type, extension: response.headers.update({"content_type": content_type, "extension": extension}),
    )
    res = _run(module.download_attachment("att1"))
    assert isinstance(res, _DummyResponse)
    assert res.data == b"attachment"
    assert res.headers["content_type"] == "application/abc"
    assert res.headers["extension"] == "abc"
    # Failure path: the storage fetch raises -> server_error_response wraps it.
    async def raise_error(*_args, **_kwargs):
        raise RuntimeError("download boom")
    monkeypatch.setattr(module, "thread_pool_exec", raise_error)
    monkeypatch.setattr(module, "server_error_response", lambda e: {"code": 500, "message": str(e)})
    res = _run(module.download_attachment("att1"))
    assert res["code"] == 500
    assert "download boom" in res["message"]
def test_change_parser_guards_and_reset_update_failure_unit(self, document_app_module, monkeypatch):
    """Walk `change_parser` through its guards, no-op paths, chunk-reset side effects and errors."""
    module = document_app_module
    monkeypatch.setattr(module, "server_error_response", lambda e: {"code": 500, "message": str(e)})
    # Guard: caller without access -> AUTHENTICATION_ERROR.
    async def req_auth_fail():
        return {"doc_id": "doc1", "parser_id": "naive", "pipeline_id": "pipe2"}
    monkeypatch.setattr(module, "get_request_json", req_auth_fail)
    monkeypatch.setattr(module.DocumentService, "accessible", lambda *_args, **_kwargs: False)
    res = _run(module.change_parser.__wrapped__())
    assert res["code"] == module.RetCode.AUTHENTICATION_ERROR
    # Guard: unknown document -> DATA_ERROR.
    monkeypatch.setattr(module.DocumentService, "accessible", lambda *_args, **_kwargs: True)
    monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (False, None))
    res = _run(module.change_parser.__wrapped__())
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "Document not found!" in res["message"]
    # No-op: requested pipeline_id equals the document's current one -> succeeds, no writes.
    async def req_same_pipeline():
        return {"doc_id": "doc1", "parser_id": "naive", "pipeline_id": "pipe1"}
    doc_same = SimpleNamespace(
        id="doc1",
        pipeline_id="pipe1",
        parser_id="naive",
        parser_config={"k": "v"},
        token_num=0,
        chunk_num=0,
        process_duration=0,
        kb_id="kb1",
        type="doc",
        name="doc.txt",
    )
    monkeypatch.setattr(module, "get_request_json", req_same_pipeline)
    monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (True, doc_same))
    res = _run(module.change_parser.__wrapped__())
    assert res["code"] == 0
    # Pipeline change: expect pipeline_id write followed by a run-status reset to UNSTART.
    calls = []
    async def req_pipeline_change():
        return {"doc_id": "doc1", "parser_id": "naive", "pipeline_id": "pipe2"}
    doc = SimpleNamespace(
        id="doc1",
        pipeline_id="pipe1",
        parser_id="naive",
        parser_config={},
        token_num=0,
        chunk_num=0,
        process_duration=0,
        kb_id="kb1",
        type="doc",
        name="doc.txt",
    )
    def fake_update_by_id(doc_id, payload):
        # Record each DB write so the order of updates can be asserted below.
        calls.append((doc_id, payload))
        return True
    monkeypatch.setattr(module, "get_request_json", req_pipeline_change)
    monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (True, doc))
    monkeypatch.setattr(module.DocumentService, "update_by_id", fake_update_by_id)
    res = _run(module.change_parser.__wrapped__())
    assert res["code"] == 0
    assert calls[0][1] == {"pipeline_id": "pipe2"}
    assert calls[1][1]["run"] == module.TaskStatus.UNSTART.value
    # With prior chunk stats present, a failing increment_chunk_num is tolerated.
    doc.token_num = 3
    doc.chunk_num = 2
    doc.process_duration = 9
    monkeypatch.setattr(module.DocumentService, "increment_chunk_num", lambda *_args, **_kwargs: False)
    res = _run(module.change_parser.__wrapped__())
    assert res["code"] == 0
    # Missing tenant short-circuits the doc-store cleanup, still succeeding.
    monkeypatch.setattr(module.DocumentService, "increment_chunk_num", lambda *_args, **_kwargs: True)
    monkeypatch.setattr(module.DocumentService, "get_tenant_id", lambda _doc_id: None)
    res = _run(module.change_parser.__wrapped__())
    assert res["code"] == 0
    # With a tenant and an existing index, chunk images and doc-store entries get purged.
    side_effects = {"img": [], "delete": []}
    class _DocStore:
        def index_exist(self, _idx, _kb_id):
            return True
        def delete(self, where, _idx, kb_id):
            side_effects["delete"].append((where["doc_id"], kb_id))
    monkeypatch.setattr(module.DocumentService, "get_tenant_id", lambda _doc_id: "tenant1")
    monkeypatch.setattr(module.DocumentService, "delete_chunk_images", lambda _doc, _tenant: side_effects["img"].append((_doc.id, _tenant)))
    monkeypatch.setattr(module.search, "index_name", lambda tenant_id: f"idx_{tenant_id}")
    monkeypatch.setattr(module.settings, "docStoreConn", _DocStore())
    res = _run(module.change_parser.__wrapped__())
    assert res["code"] == 0
    assert ("doc1", "tenant1") in side_effects["img"]
    assert ("doc1", "kb1") in side_effects["delete"]
    # No-op: same parser_id with an identical parser_config.
    async def req_same_parser_with_cfg():
        return {"doc_id": "doc1", "parser_id": "naive", "parser_config": {"a": 1}}
    doc_same_parser = SimpleNamespace(
        id="doc1",
        pipeline_id="pipe1",
        parser_id="naive",
        parser_config={"a": 1},
        token_num=0,
        chunk_num=0,
        process_duration=0,
        kb_id="kb1",
        type="doc",
        name="doc.txt",
    )
    monkeypatch.setattr(module, "get_request_json", req_same_parser_with_cfg)
    monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (True, doc_same_parser))
    res = _run(module.change_parser.__wrapped__())
    assert res["code"] == 0
    # No-op: same parser_id and no parser_config at all.
    async def req_same_parser_no_cfg():
        return {"doc_id": "doc1", "parser_id": "naive"}
    monkeypatch.setattr(module, "get_request_json", req_same_parser_no_cfg)
    res = _run(module.change_parser.__wrapped__())
    assert res["code"] == 0
    # Parser switch: the provided parser_config is persisted via update_parser_config.
    parser_cfg_updates = []
    async def req_parser_update():
        return {"doc_id": "doc1", "parser_id": "paper", "pipeline_id": "", "parser_config": {"beta": True}}
    doc_parser_update = SimpleNamespace(
        id="doc1",
        pipeline_id="pipe1",
        parser_id="naive",
        parser_config={"alpha": 1},
        token_num=0,
        chunk_num=0,
        process_duration=0,
        kb_id="kb1",
        type="doc",
        name="doc.txt",
    )
    monkeypatch.setattr(module, "get_request_json", req_parser_update)
    monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (True, doc_parser_update))
    monkeypatch.setattr(module.DocumentService, "update_parser_config", lambda doc_id, cfg: parser_cfg_updates.append((doc_id, cfg)))
    monkeypatch.setattr(module.DocumentService, "update_by_id", lambda *_args, **_kwargs: True)
    res = _run(module.change_parser.__wrapped__())
    assert res["code"] == 0
    assert parser_cfg_updates == [("doc1", {"beta": True})]
    # Error path: update_parser_config raising surfaces through server_error_response.
    def raise_parser_config(*_args, **_kwargs):
        raise RuntimeError("parser boom")
    monkeypatch.setattr(module.DocumentService, "update_parser_config", raise_parser_config)
    res = _run(module.change_parser.__wrapped__())
    assert res["code"] == 500
    assert "parser boom" in res["message"]
def test_get_image_success_and_exception_unit(self, document_app_module, monkeypatch):
    """Exercise `get_image`: happy path sets a JPEG content type, failures go to server_error_response."""
    module = document_app_module

    class _HeaderBag(dict):
        # Flask-style header container: `.set` is how the route writes headers.
        def set(self, key, value):
            self[key] = value

    class _StubImageResponse:
        def __init__(self, data):
            self.data = data
            self.headers = _HeaderBag()

    async def stub_pool_exec(*_args, **_kwargs):
        return b"image-bytes"

    async def stub_make_response(data):
        return _StubImageResponse(data)

    monkeypatch.setattr(module, "thread_pool_exec", stub_pool_exec)
    monkeypatch.setattr(module, "make_response", stub_make_response)
    monkeypatch.setattr(module.settings, "STORAGE_IMPL", SimpleNamespace(get=lambda *_args, **_kwargs: b"image-bytes"))

    # Success: blob is returned verbatim and tagged as image/JPEG.
    result = _run(module.get_image("bucket-name"))
    assert isinstance(result, _StubImageResponse)
    assert result.data == b"image-bytes"
    assert result.headers["Content-Type"] == "image/JPEG"

    # Failure: an exception in the storage fetch is wrapped by server_error_response.
    async def stub_pool_exec_boom(*_args, **_kwargs):
        raise RuntimeError("image boom")

    monkeypatch.setattr(module, "thread_pool_exec", stub_pool_exec_boom)
    monkeypatch.setattr(module, "server_error_response", lambda e: {"code": 500, "message": str(e)})
    result = _run(module.get_image("bucket-name"))
    assert result["code"] == 500
    assert "image boom" in result["message"]
def test_set_meta_validation_and_persistence_matrix_unit(self, document_app_module, monkeypatch):
    """Cover `set_meta` input validation (auth, JSON shape/syntax) and persistence failures."""
    module = document_app_module
    def set_req(payload):
        # Helper: make the route see `payload` as its JSON request body.
        async def fake_request_json():
            return payload
        monkeypatch.setattr(module, "get_request_json", fake_request_json)
    # Auth guard.
    set_req({"doc_id": "doc1", "meta": "{}"})
    monkeypatch.setattr(module.DocumentService, "accessible", lambda *_args, **_kwargs: False)
    res = _run(module.set_meta.__wrapped__())
    assert res["code"] == module.RetCode.AUTHENTICATION_ERROR
    monkeypatch.setattr(module.DocumentService, "accessible", lambda *_args, **_kwargs: True)
    # Top-level meta must be a JSON object, not a list.
    set_req({"doc_id": "doc1", "meta": "[]"})
    res = _run(module.set_meta.__wrapped__())
    assert res["code"] == module.RetCode.ARGUMENT_ERROR
    assert "Only dictionary type supported." in res["message"]
    # List values may not contain nested objects.
    set_req({"doc_id": "doc1", "meta": '{"tags":[{"x":1}]}'})
    res = _run(module.set_meta.__wrapped__())
    assert res["code"] == module.RetCode.ARGUMENT_ERROR
    assert "The type is not supported in list" in res["message"]
    # Nested object values are rejected outright.
    set_req({"doc_id": "doc1", "meta": '{"obj":{"x":1}}'})
    res = _run(module.set_meta.__wrapped__())
    assert res["code"] == module.RetCode.ARGUMENT_ERROR
    assert "The type is not supported" in res["message"]
    # Malformed JSON reports a syntax error.
    set_req({"doc_id": "doc1", "meta": "{"})
    res = _run(module.set_meta.__wrapped__())
    assert res["code"] == module.RetCode.ARGUMENT_ERROR
    assert "Json syntax error:" in res["message"]
    # Valid meta but unknown document -> DATA_ERROR.
    set_req({"doc_id": "doc1", "meta": '{"author":"alice"}'})
    monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (False, None))
    res = _run(module.set_meta.__wrapped__())
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "Document not found!" in res["message"]
    # Document found but the metadata write fails -> DB error surfaced.
    monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (True, SimpleNamespace(id="doc1")))
    monkeypatch.setattr(module.DocMetadataService, "update_document_metadata", lambda *_args, **_kwargs: False)
    res = _run(module.set_meta.__wrapped__())
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "Database error (meta updates)!" in res["message"]

View File

@ -13,7 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
from concurrent.futures import ThreadPoolExecutor, as_completed
from types import SimpleNamespace
import pytest
from common import bulk_upload_documents, list_documents, parse_documents
@ -22,6 +24,10 @@ from libs.auth import RAGFlowWebApiAuth
from utils import wait_for
def _run(coro):
return asyncio.run(coro)
@wait_for(30, 1, "Document parsing timeout")
def condition(_auth, _kb_id, _document_ids=None):
res = list_documents(_auth, {"kb_id": _kb_id})
@ -194,6 +200,94 @@ def test_concurrent_parse(WebApiAuth, add_dataset_func, tmp_path):
validate_document_parse_done(WebApiAuth, kb_id, document_ids)
@pytest.mark.p2
class TestDocumentsParseUnit:
    """Unit coverage for the document `run` route's branch matrix."""

    def test_run_branch_matrix_unit(self, document_app_module, monkeypatch):
        """Drive `run` through auth/tenant/document guards, cancel rejection, rerun reset and errors."""
        module = document_app_module
        # Shared recorders for every side effect the route may trigger.
        calls = {"clear": [], "filter_delete": [], "docstore_delete": [], "cancel": [], "run": []}
        async def fake_thread_pool_exec(func, *args, **kwargs):
            # Run pool work inline so side effects happen synchronously.
            return func(*args, **kwargs)
        monkeypatch.setattr(module, "thread_pool_exec", fake_thread_pool_exec)
        monkeypatch.setattr(module, "server_error_response", lambda e: {"code": 500, "message": str(e)})
        monkeypatch.setattr(module.search, "index_name", lambda tenant_id: f"idx_{tenant_id}")
        monkeypatch.setattr(module, "cancel_all_task_of", lambda doc_id: calls["cancel"].append(doc_id))
        class _DocStore:
            def index_exist(self, _index_name, _kb_id):
                return True
            def delete(self, where, _index_name, _kb_id):
                calls["docstore_delete"].append(where["doc_id"])
        monkeypatch.setattr(module.settings, "docStoreConn", _DocStore())
        async def set_request(payload):
            return payload
        def apply_request(payload):
            # Helper: make the route see `payload` as its JSON request body.
            async def fake_request_json():
                return await set_request(payload)
            monkeypatch.setattr(module, "get_request_json", fake_request_json)
        # Guard: no access -> AUTHENTICATION_ERROR.
        apply_request({"doc_ids": ["doc1"], "run": module.TaskStatus.RUNNING.value})
        monkeypatch.setattr(module.DocumentService, "accessible", lambda *_args, **_kwargs: False)
        res = _run(module.run.__wrapped__())
        assert res["code"] == module.RetCode.AUTHENTICATION_ERROR
        # Guard: missing tenant -> DATA_ERROR.
        monkeypatch.setattr(module.DocumentService, "accessible", lambda *_args, **_kwargs: True)
        monkeypatch.setattr(module.DocumentService, "get_tenant_id", lambda _doc_id: None)
        res = _run(module.run.__wrapped__())
        assert res["code"] == module.RetCode.DATA_ERROR
        assert "Tenant not found!" in res["message"]
        # Guard: missing document -> DATA_ERROR.
        monkeypatch.setattr(module.DocumentService, "get_tenant_id", lambda _doc_id: "tenant1")
        monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (False, None))
        res = _run(module.run.__wrapped__())
        assert res["code"] == module.RetCode.DATA_ERROR
        assert "Document not found!" in res["message"]
        # Cancel is rejected when the document's task is not currently RUNNING.
        apply_request({"doc_ids": ["doc1"], "run": module.TaskStatus.CANCEL.value})
        doc_cancel = SimpleNamespace(id="doc1", run=module.TaskStatus.DONE.value, kb_id="kb1", parser_config={}, to_dict=lambda: {"id": "doc1"})
        monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (True, doc_cancel))
        monkeypatch.setattr(module.TaskService, "query", lambda **_kwargs: [SimpleNamespace(progress=1)])
        res = _run(module.run.__wrapped__())
        assert res["code"] == module.RetCode.DATA_ERROR
        assert "Cannot cancel a task that is not in RUNNING status" in res["message"]
        # Rerun with delete=True: chunk stats cleared, tasks deleted, doc-store purged, doc re-queued.
        apply_request({"doc_ids": ["doc1"], "run": module.TaskStatus.RUNNING.value, "delete": True})
        doc_rerun = SimpleNamespace(id="doc1", run=module.TaskStatus.DONE.value, kb_id="kb1", parser_config={}, to_dict=lambda: {"id": "doc1"})
        monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (True, doc_rerun))
        monkeypatch.setattr(module.DocumentService, "clear_chunk_num_when_rerun", lambda doc_id: calls["clear"].append(doc_id))
        monkeypatch.setattr(module.TaskService, "filter_delete", lambda _filters: calls["filter_delete"].append(True))
        monkeypatch.setattr(module.DocumentService, "update_by_id", lambda *_args, **_kwargs: True)
        monkeypatch.setattr(module.DocumentService, "run", lambda tenant_id, doc_dict, _kb_map: calls["run"].append((tenant_id, doc_dict)))
        res = _run(module.run.__wrapped__())
        assert res["code"] == 0
        assert calls["clear"] == ["doc1"]
        assert calls["filter_delete"] == [True]
        assert calls["docstore_delete"] == ["doc1"]
        assert calls["run"] == [("tenant1", {"id": "doc1"})]
        # apply_kb=True with an unknown knowledgebase -> error.
        apply_request({"doc_ids": ["doc1"], "run": module.TaskStatus.RUNNING.value, "apply_kb": True})
        monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (False, None))
        res = _run(module.run.__wrapped__())
        assert res["code"] == 500
        assert "Can't find this dataset!" in res["message"]
        # DocumentService.run raising surfaces through server_error_response.
        apply_request({"doc_ids": ["doc1"], "run": module.TaskStatus.RUNNING.value})
        def raise_run_error(*_args, **_kwargs):
            raise RuntimeError("run boom")
        monkeypatch.setattr(module.DocumentService, "run", raise_run_error)
        res = _run(module.run.__wrapped__())
        assert res["code"] == 500
        assert "run boom" in res["message"]
# @pytest.mark.skip
class TestDocumentsParseStop:
@pytest.mark.parametrize(

View File

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
from concurrent.futures import ThreadPoolExecutor, as_completed
import pytest
@ -21,6 +22,10 @@ from configs import INVALID_API_TOKEN
from libs.auth import RAGFlowWebApiAuth
def _run(coro):
return asyncio.run(coro)
@pytest.mark.p2
class TestAuthorization:
@pytest.mark.parametrize(
@ -75,6 +80,32 @@ class TestDocumentsDeletion:
assert res["message"] == "No authorization.", res
@pytest.mark.p2
class TestDocumentsDeletionUnit:
    """Unit coverage for the `rm` route's string-doc_id normalization."""

    def test_rm_string_doc_id_normalization_success_unit(self, document_app_module, monkeypatch):
        """A bare string doc_id is wrapped into a list before the pooled delete runs."""
        module = document_app_module
        seen = {}

        async def stub_request_json():
            # The payload deliberately uses a plain string, not a list.
            return {"doc_id": "doc1"}

        async def stub_pool_exec(func, doc_ids, user_id):
            # Capture what the route would have handed to the worker pool.
            seen["func"] = func
            seen["doc_ids"] = doc_ids
            seen["user_id"] = user_id
            return None

        monkeypatch.setattr(module, "get_request_json", stub_request_json)
        monkeypatch.setattr(module.DocumentService, "accessible4deletion", lambda *_args, **_kwargs: True)
        monkeypatch.setattr(module, "thread_pool_exec", stub_pool_exec)

        result = _run(module.rm.__wrapped__())

        assert result["code"] == 0
        assert result["data"] is True
        assert seen["func"] == module.FileService.delete_docs
        assert seen["doc_ids"] == ["doc1"]
        assert seen["user_id"] == module.current_user.id
@pytest.mark.p3
def test_concurrent_deletion(WebApiAuth, add_dataset, tmp_path):
count = 100

View File

@ -14,8 +14,9 @@
# limitations under the License.
#
import asyncio
import sys
import string
from types import SimpleNamespace
from types import ModuleType, SimpleNamespace
from concurrent.futures import ThreadPoolExecutor, as_completed
import pytest
@ -333,6 +334,136 @@ class TestDocumentsUploadUnit:
assert res["code"] == 102
assert "file format" in res["message"]
def test_upload_and_parse_matrix_unit(self, document_app_module, monkeypatch):
    """`upload_and_parse`: empty filename is rejected; a real file is parsed and its ids returned."""
    module = document_app_module

    # An upload whose filename is empty must fail with ARGUMENT_ERROR.
    empty_upload = _DummyFiles({"file": [_DummyFile("")]})
    monkeypatch.setattr(module, "request", _DummyRequest(form={"conversation_id": "conv-1"}, files=empty_upload))
    result = _run(module.upload_and_parse.__wrapped__())
    assert result["code"] == module.RetCode.ARGUMENT_ERROR
    assert result["message"] == "No file selected!"

    # A named file is handed to doc_upload_and_parse and its ids echoed back.
    named_upload = _DummyFiles({"file": [_DummyFile("note.txt")]})
    monkeypatch.setattr(module, "request", _DummyRequest(form={"conversation_id": "conv-1"}, files=named_upload))
    monkeypatch.setattr(module, "doc_upload_and_parse", lambda _conv_id, _files, _uid: ["doc-1"])
    result = _run(module.upload_and_parse.__wrapped__())
    assert result["code"] == 0
    assert result["data"] == ["doc-1"]
def test_parse_url_and_multipart_matrix_unit(self, document_app_module, monkeypatch, tmp_path):
    """Cover `parse` for: invalid URL, crawled HTML, crawled file download, and multipart uploads."""
    module = document_app_module
    # Invalid URL is rejected before any crawling happens.
    async def req_invalid_url():
        return {"url": "not-a-url"}
    monkeypatch.setattr(module, "get_request_json", req_invalid_url)
    monkeypatch.setattr(module, "is_valid_url", lambda _url: False)
    res = _run(module.parse())
    assert res["code"] == module.RetCode.ARGUMENT_ERROR
    assert res["message"] == "The URL format is invalid"
    # Build a fake `seleniumwire` package so the route's lazy import resolves to stubs.
    webdriver_mod = ModuleType("seleniumwire.webdriver")
    class _FakeChromeOptions:
        def __init__(self):
            self.args = []
            self.experimental = {}
        def add_argument(self, arg):
            self.args.append(arg)
        def add_experimental_option(self, key, value):
            self.experimental[key] = value
    class _Req:
        # Mimics a captured selenium-wire request carrying response headers.
        def __init__(self, headers):
            self.response = SimpleNamespace(headers=headers)
    class _FakeDriver:
        def __init__(self, requests, page_source):
            self.requests = requests
            self.page_source = page_source
            self.quit_called = False
            self.visited = []
            self.options = None
        def get(self, url):
            self.visited.append(url)
        def quit(self):
            self.quit_called = True
    # `queue` holds the driver each Chrome() construction should return; `created` records them.
    queue = []
    created = []
    def _fake_chrome(options=None):
        driver = queue.pop(0)
        driver.options = options
        created.append(driver)
        return driver
    webdriver_mod.Chrome = _fake_chrome
    webdriver_mod.ChromeOptions = _FakeChromeOptions
    seleniumwire_mod = ModuleType("seleniumwire")
    seleniumwire_mod.webdriver = webdriver_mod
    monkeypatch.setitem(sys.modules, "seleniumwire", seleniumwire_mod)
    monkeypatch.setitem(sys.modules, "seleniumwire.webdriver", webdriver_mod)
    monkeypatch.setattr(module, "get_project_base_directory", lambda: str(tmp_path))
    monkeypatch.setattr(module, "is_valid_url", lambda _url: True)
    class _Parser:
        def parser_txt(self, page_source):
            # Sanity-check the driver's page source actually reaches the parser.
            assert "page" in page_source
            return ["section1", "section2"]
    monkeypatch.setattr(module, "RAGFlowHtmlParser", lambda: _Parser())
    # HTML page: no content-disposition header -> parsed sections joined with newlines.
    queue.append(_FakeDriver([_Req({"x": "1"}), _Req({"y": "2"})], "<html>page</html>"))
    async def req_url_html():
        return {"url": "http://example.com/html"}
    monkeypatch.setattr(module, "get_request_json", req_url_html)
    res = _run(module.parse())
    assert res["code"] == 0
    assert res["data"] == "section1\nsection2"
    assert created[-1].quit_called is True
    # File download: content-disposition names doc.txt, pre-seeded in the downloads dir.
    (tmp_path / "logs" / "downloads").mkdir(parents=True, exist_ok=True)
    (tmp_path / "logs" / "downloads" / "doc.txt").write_bytes(b"downloaded-bytes")
    queue.append(_FakeDriver([_Req({"content-disposition": 'attachment; filename="doc.txt"'})], "<html>file</html>"))
    captured = {}
    def parse_docs_read(files, _uid):
        captured["filename"] = files[0].filename
        captured["content"] = files[0].read()
        return "parsed-download"
    monkeypatch.setattr(module.FileService, "parse_docs", parse_docs_read)
    async def req_url_file():
        return {"url": "http://example.com/file"}
    monkeypatch.setattr(module, "get_request_json", req_url_file)
    res = _run(module.parse())
    assert res["code"] == 0
    assert res["data"] == "parsed-download"
    assert captured["filename"] == "doc.txt"
    assert captured["content"] == b"downloaded-bytes"
    # Multipart path: no "file" part -> ARGUMENT_ERROR.
    async def req_no_url():
        return {}
    monkeypatch.setattr(module, "get_request_json", req_no_url)
    monkeypatch.setattr(module, "request", _DummyRequest(files=_DummyFiles()))
    res = _run(module.parse())
    assert res["code"] == module.RetCode.ARGUMENT_ERROR
    assert res["message"] == "No file part!"
    # Multipart path: file present -> delegated to FileService.parse_docs.
    monkeypatch.setattr(module, "request", _DummyRequest(files=_DummyFiles({"file": [_DummyFile("f1.txt")]})))
    monkeypatch.setattr(module.FileService, "parse_docs", lambda _files, _uid: "parsed-upload")
    res = _run(module.parse())
    assert res["code"] == 0
    assert res["data"] == "parsed-upload"
@pytest.mark.p2
class TestWebCrawlUnit:

View File

@ -0,0 +1,575 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import importlib.util
import sys
from pathlib import Path
from types import ModuleType, SimpleNamespace
import pytest
class _DummyManager:
    """Stand-in for the blueprint/route manager: registering a route is a no-op."""

    def route(self, *_args, **_kwargs):
        # Ignore the path and options; the decorator hands the view back unchanged.
        def _identity(view_func):
            return view_func

        return _identity
class _Args(dict):
    """Minimal stand-in for a Quart/Werkzeug ``request.args`` mapping.

    The previous ``get`` override simply forwarded to ``super().get(key, default)``
    with the same ``default=None`` signature, which is exactly what ``dict.get``
    already does — so the override was redundant and has been removed. Subclassing
    is kept only to give the stub a distinct type.
    """
class _DummyRetCode:
    # Stand-in for common.constants.RetCode; the numeric values are what the
    # evaluation_app route tests compare response codes against.
    SUCCESS = 0
    EXCEPTION_ERROR = 100
    ARGUMENT_ERROR = 101
    DATA_ERROR = 102
    OPERATING_ERROR = 103
    AUTHENTICATION_ERROR = 109
def _run(coro):
    """Synchronously drive *coro* to completion on a fresh event loop and return its result."""
    outcome = asyncio.run(coro)
    return outcome
def _set_request_json(monkeypatch, module, payload):
    """Patch *module*'s ``get_request_json`` so the route under test receives *payload*."""

    async def _fake_body():
        return payload

    monkeypatch.setattr(module, "get_request_json", _fake_body)
def _set_request_args(monkeypatch, module, args=None):
    """Patch *module*'s ``request`` with a stub whose ``.args`` wraps *args* (empty by default)."""
    query = _Args(args or {})
    monkeypatch.setattr(module, "request", SimpleNamespace(args=query))
@pytest.fixture(scope="session")
def auth():
    # Session-wide dummy auth token; these unit tests never contact a live server.
    return "unit-auth"
@pytest.fixture(scope="session", autouse=True)
def set_tenant_info():
return None
def _load_evaluation_app(monkeypatch):
    """Load ``api/apps/evaluation_app.py`` in isolation, with all imports stubbed.

    Installs stub modules into ``sys.modules`` for quart, common.constants,
    api.apps, api.db.services.evaluation_service and api.utils.api_utils, then
    executes the real route module from file so its handlers can be unit-tested
    without the framework or database. Returns the loaded module object.
    """
    repo_root = Path(__file__).resolve().parents[4]
    # --- quart stub: only `request` with empty query args is needed.
    quart_mod = ModuleType("quart")
    quart_mod.request = SimpleNamespace(args=_Args())
    monkeypatch.setitem(sys.modules, "quart", quart_mod)
    # --- common.constants stub exposing the RetCode values.
    common_pkg = ModuleType("common")
    common_pkg.__path__ = [str(repo_root / "common")]
    monkeypatch.setitem(sys.modules, "common", common_pkg)
    constants_mod = ModuleType("common.constants")
    constants_mod.RetCode = _DummyRetCode
    monkeypatch.setitem(sys.modules, "common.constants", constants_mod)
    common_pkg.constants = constants_mod
    # --- api / api.apps stubs: current_user plus a pass-through login_required.
    api_pkg = ModuleType("api")
    api_pkg.__path__ = [str(repo_root / "api")]
    monkeypatch.setitem(sys.modules, "api", api_pkg)
    apps_mod = ModuleType("api.apps")
    apps_mod.__path__ = [str(repo_root / "api" / "apps")]
    apps_mod.current_user = SimpleNamespace(id="tenant-1")
    apps_mod.login_required = lambda func: func
    monkeypatch.setitem(sys.modules, "api.apps", apps_mod)
    api_pkg.apps = apps_mod
    db_pkg = ModuleType("api.db")
    db_pkg.__path__ = []
    monkeypatch.setitem(sys.modules, "api.db", db_pkg)
    api_pkg.db = db_pkg
    services_pkg = ModuleType("api.db.services")
    services_pkg.__path__ = []
    monkeypatch.setitem(sys.modules, "api.db.services", services_pkg)
    # --- EvaluationService stub with benign defaults; tests monkeypatch per case.
    evaluation_service_mod = ModuleType("api.db.services.evaluation_service")
    class _EvaluationService:
        @staticmethod
        def create_dataset(**_kwargs):
            return True, "dataset-1"
        @staticmethod
        def list_datasets(**_kwargs):
            return {"datasets": [], "total": 0}
        @staticmethod
        def get_dataset(_dataset_id):
            return {"id": _dataset_id}
        @staticmethod
        def update_dataset(_dataset_id, **_kwargs):
            return True
        @staticmethod
        def delete_dataset(_dataset_id):
            return True
        @staticmethod
        def add_test_case(**_kwargs):
            return True, "case-1"
        @staticmethod
        def import_test_cases(**_kwargs):
            return 0, 0
        @staticmethod
        def get_test_cases(_dataset_id):
            return []
        @staticmethod
        def delete_test_case(_case_id):
            return True
        @staticmethod
        def start_evaluation(**_kwargs):
            return True, "run-1"
        @staticmethod
        def get_run_results(_run_id):
            return {"id": _run_id}
        @staticmethod
        def get_recommendations(_run_id):
            return []
    evaluation_service_mod.EvaluationService = _EvaluationService
    monkeypatch.setitem(sys.modules, "api.db.services.evaluation_service", evaluation_service_mod)
    # --- api.utils.api_utils stub: dict-returning response helpers so routes
    #     yield plain dicts the tests can assert on directly.
    utils_pkg = ModuleType("api.utils")
    utils_pkg.__path__ = []
    monkeypatch.setitem(sys.modules, "api.utils", utils_pkg)
    api_utils_mod = ModuleType("api.utils.api_utils")
    async def _default_request_json():
        return {}
    def _get_data_error_result(code=_DummyRetCode.DATA_ERROR, message="Sorry! Data missing!"):
        return {"code": code, "message": message}
    def _get_json_result(code=_DummyRetCode.SUCCESS, message="success", data=None):
        return {"code": code, "message": message, "data": data}
    def _server_error_response(error):
        return {"code": _DummyRetCode.EXCEPTION_ERROR, "message": repr(error)}
    def _validate_request(*_args, **_kwargs):
        def _decorator(func):
            return func
        return _decorator
    api_utils_mod.get_data_error_result = _get_data_error_result
    api_utils_mod.get_json_result = _get_json_result
    api_utils_mod.get_request_json = _default_request_json
    api_utils_mod.server_error_response = _server_error_response
    api_utils_mod.validate_request = _validate_request
    monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod)
    utils_pkg.api_utils = api_utils_mod
    # --- Execute evaluation_app.py from file under a private module name,
    #     injecting the dummy route manager before its decorators run.
    module_name = "test_evaluation_routes_unit_module"
    module_path = repo_root / "api" / "apps" / "evaluation_app.py"
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    module = importlib.util.module_from_spec(spec)
    module.manager = _DummyManager()
    monkeypatch.setitem(sys.modules, module_name, module)
    spec.loader.exec_module(module)
    return module
@pytest.mark.p2
def test_dataset_routes_matrix_unit(monkeypatch):
    """Cover the dataset CRUD routes: create/list/get/update/delete, each with success, validation and exception branches."""
    module = _load_evaluation_app(monkeypatch)
    # create: success (name is trimmed upstream of the service call).
    _set_request_json(monkeypatch, module, {"name": " data-1 ", "description": "desc", "kb_ids": ["kb-1"]})
    monkeypatch.setattr(module.EvaluationService, "create_dataset", lambda **_kwargs: (True, "dataset-ok"))
    res = _run(module.create_dataset())
    assert res["code"] == 0
    assert res["data"]["dataset_id"] == "dataset-ok"
    # create: blank name rejected.
    _set_request_json(monkeypatch, module, {"name": " ", "kb_ids": ["kb-1"]})
    res = _run(module.create_dataset())
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "empty" in res["message"].lower()
    # create: kb_ids must be a list, not a string.
    _set_request_json(monkeypatch, module, {"name": "data-2", "kb_ids": "kb-1"})
    res = _run(module.create_dataset())
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "kb_ids" in res["message"]
    # create: service-level failure message is propagated.
    _set_request_json(monkeypatch, module, {"name": "data-3", "kb_ids": ["kb-1"]})
    monkeypatch.setattr(module.EvaluationService, "create_dataset", lambda **_kwargs: (False, "create failed"))
    res = _run(module.create_dataset())
    assert res["code"] == module.RetCode.DATA_ERROR
    assert res["message"] == "create failed"
    # create: service exception -> EXCEPTION_ERROR.
    def _raise_create(**_kwargs):
        raise RuntimeError("create boom")
    monkeypatch.setattr(module.EvaluationService, "create_dataset", _raise_create)
    res = _run(module.create_dataset())
    assert res["code"] == module.RetCode.EXCEPTION_ERROR
    assert "create boom" in res["message"]
    # list: success with paging args.
    _set_request_args(monkeypatch, module, {"page": "2", "page_size": "3"})
    monkeypatch.setattr(module.EvaluationService, "list_datasets", lambda **_kwargs: {"datasets": [{"id": "a"}], "total": 1})
    res = _run(module.list_datasets())
    assert res["code"] == 0
    assert res["data"]["total"] == 1
    # list: non-numeric page -> int() failure surfaces as EXCEPTION_ERROR.
    _set_request_args(monkeypatch, module, {"page": "x"})
    res = _run(module.list_datasets())
    assert res["code"] == module.RetCode.EXCEPTION_ERROR
    # get: missing dataset, then found, then service exception.
    monkeypatch.setattr(module.EvaluationService, "get_dataset", lambda _dataset_id: None)
    res = _run(module.get_dataset("dataset-1"))
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "not found" in res["message"].lower()
    monkeypatch.setattr(module.EvaluationService, "get_dataset", lambda _dataset_id: {"id": _dataset_id})
    res = _run(module.get_dataset("dataset-2"))
    assert res["code"] == 0
    assert res["data"]["id"] == "dataset-2"
    def _raise_get(_dataset_id):
        raise RuntimeError("get dataset boom")
    monkeypatch.setattr(module.EvaluationService, "get_dataset", _raise_get)
    res = _run(module.get_dataset("dataset-3"))
    assert res["code"] == module.RetCode.EXCEPTION_ERROR
    assert "get dataset boom" in res["message"]
    # update: protected fields (id/tenant_id/created_by/create_time) are stripped
    # from the payload before it reaches the service.
    captured = {}
    def _update(dataset_id, **kwargs):
        captured["dataset_id"] = dataset_id
        captured["kwargs"] = kwargs
        return True
    _set_request_json(
        monkeypatch,
        module,
        {
            "id": "forbidden",
            "tenant_id": "forbidden",
            "created_by": "forbidden",
            "create_time": 123,
            "name": "new-name",
        },
    )
    monkeypatch.setattr(module.EvaluationService, "update_dataset", _update)
    res = _run(module.update_dataset("dataset-4"))
    assert res["code"] == 0
    assert res["data"]["dataset_id"] == "dataset-4"
    assert captured["dataset_id"] == "dataset-4"
    assert "id" not in captured["kwargs"]
    assert "tenant_id" not in captured["kwargs"]
    assert "created_by" not in captured["kwargs"]
    assert "create_time" not in captured["kwargs"]
    # update: service returns False -> DATA_ERROR; raising -> EXCEPTION_ERROR.
    _set_request_json(monkeypatch, module, {"name": "new-name"})
    monkeypatch.setattr(module.EvaluationService, "update_dataset", lambda _dataset_id, **_kwargs: False)
    res = _run(module.update_dataset("dataset-5"))
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "failed" in res["message"].lower()
    def _raise_update(_dataset_id, **_kwargs):
        raise RuntimeError("update boom")
    monkeypatch.setattr(module.EvaluationService, "update_dataset", _raise_update)
    res = _run(module.update_dataset("dataset-6"))
    assert res["code"] == module.RetCode.EXCEPTION_ERROR
    assert "update boom" in res["message"]
    # delete: failure, success, then service exception.
    monkeypatch.setattr(module.EvaluationService, "delete_dataset", lambda _dataset_id: False)
    res = _run(module.delete_dataset("dataset-7"))
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "failed" in res["message"].lower()
    monkeypatch.setattr(module.EvaluationService, "delete_dataset", lambda _dataset_id: True)
    res = _run(module.delete_dataset("dataset-8"))
    assert res["code"] == 0
    assert res["data"]["dataset_id"] == "dataset-8"
    def _raise_delete(_dataset_id):
        raise RuntimeError("delete dataset boom")
    monkeypatch.setattr(module.EvaluationService, "delete_dataset", _raise_delete)
    res = _run(module.delete_dataset("dataset-9"))
    assert res["code"] == module.RetCode.EXCEPTION_ERROR
    assert "delete dataset boom" in res["message"]
@pytest.mark.p2
def test_test_case_routes_matrix_unit(monkeypatch):
    """Matrix test for the evaluation app's test-case CRUD routes.

    Covers, in order: blank-question validation, service-reported add failure,
    a fully-populated successful add, service exceptions, the import endpoint
    (bad payload type, success counts, exception), the list endpoint (success
    and exception), and the delete endpoint (failure, success, exception).
    """
    module = _load_evaluation_app(monkeypatch)
    # add_test_case: a whitespace-only question is rejected before the service runs.
    _set_request_json(monkeypatch, module, {"question": " "})
    res = _run(module.add_test_case("dataset-1"))
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "question" in res["message"].lower()
    # add_test_case: service signals failure -> DATA_ERROR carrying its message.
    _set_request_json(monkeypatch, module, {"question": "q1"})
    monkeypatch.setattr(module.EvaluationService, "add_test_case", lambda **_kwargs: (False, "add failed"))
    res = _run(module.add_test_case("dataset-2"))
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "add failed" in res["message"]
    # add_test_case: fully-populated payload succeeds and echoes the new case id.
    _set_request_json(
        monkeypatch,
        module,
        {
            "question": "q2",
            "reference_answer": "a2",
            "relevant_doc_ids": ["doc-1"],
            "relevant_chunk_ids": ["chunk-1"],
            "metadata": {"k": "v"},
        },
    )
    monkeypatch.setattr(module.EvaluationService, "add_test_case", lambda **_kwargs: (True, "case-ok"))
    res = _run(module.add_test_case("dataset-3"))
    assert res["code"] == 0
    assert res["data"]["case_id"] == "case-ok"
    # add_test_case: a service exception surfaces as EXCEPTION_ERROR.
    def _raise_add(**_kwargs):
        raise RuntimeError("add case boom")
    monkeypatch.setattr(module.EvaluationService, "add_test_case", _raise_add)
    res = _run(module.add_test_case("dataset-4"))
    assert res["code"] == module.RetCode.EXCEPTION_ERROR
    assert "add case boom" in res["message"]
    # import_test_cases: "cases" must be a list; a mapping is rejected.
    _set_request_json(monkeypatch, module, {"cases": {}})
    res = _run(module.import_test_cases("dataset-5"))
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "cases" in res["message"]
    # import_test_cases: success/failure counts are echoed back with the total.
    _set_request_json(monkeypatch, module, {"cases": [{"question": "q1"}, {"question": "q2"}]})
    monkeypatch.setattr(module.EvaluationService, "import_test_cases", lambda **_kwargs: (2, 0))
    res = _run(module.import_test_cases("dataset-6"))
    assert res["code"] == 0
    assert res["data"]["success_count"] == 2
    assert res["data"]["failure_count"] == 0
    assert res["data"]["total"] == 2
    # import_test_cases: a service exception surfaces as EXCEPTION_ERROR.
    def _raise_import(**_kwargs):
        raise RuntimeError("import boom")
    monkeypatch.setattr(module.EvaluationService, "import_test_cases", _raise_import)
    res = _run(module.import_test_cases("dataset-7"))
    assert res["code"] == module.RetCode.EXCEPTION_ERROR
    assert "import boom" in res["message"]
    # get_test_cases: happy path returns the case list plus a total count.
    monkeypatch.setattr(module.EvaluationService, "get_test_cases", lambda _dataset_id: [{"id": "case-1"}])
    res = _run(module.get_test_cases("dataset-8"))
    assert res["code"] == 0
    assert res["data"]["total"] == 1
    assert res["data"]["cases"][0]["id"] == "case-1"
    # get_test_cases: a service exception surfaces as EXCEPTION_ERROR.
    def _raise_get_cases(_dataset_id):
        raise RuntimeError("get cases boom")
    monkeypatch.setattr(module.EvaluationService, "get_test_cases", _raise_get_cases)
    res = _run(module.get_test_cases("dataset-9"))
    assert res["code"] == module.RetCode.EXCEPTION_ERROR
    assert "get cases boom" in res["message"]
    # delete_test_case: service returning False -> DATA_ERROR.
    monkeypatch.setattr(module.EvaluationService, "delete_test_case", lambda _case_id: False)
    res = _run(module.delete_test_case("case-1"))
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "failed" in res["message"].lower()
    # delete_test_case: success echoes the deleted case id.
    monkeypatch.setattr(module.EvaluationService, "delete_test_case", lambda _case_id: True)
    res = _run(module.delete_test_case("case-2"))
    assert res["code"] == 0
    assert res["data"]["case_id"] == "case-2"
    # delete_test_case: a service exception surfaces as EXCEPTION_ERROR.
    def _raise_delete_case(_case_id):
        raise RuntimeError("delete case boom")
    monkeypatch.setattr(module.EvaluationService, "delete_test_case", _raise_delete_case)
    res = _run(module.delete_test_case("case-3"))
    assert res["code"] == module.RetCode.EXCEPTION_ERROR
    assert "delete case boom" in res["message"]
@pytest.mark.p2
def test_run_and_recommendation_routes_matrix_unit(monkeypatch):
    """Matrix test for evaluation-run lifecycle and recommendation routes.

    Exercises start_evaluation, get_evaluation_run, get_run_results,
    list_evaluation_runs, delete_evaluation_run and get_recommendations,
    covering success, service-reported failure, and raised-exception paths.
    """
    module = _load_evaluation_app(monkeypatch)
    # start_evaluation: service failure -> DATA_ERROR with its message.
    _set_request_json(monkeypatch, module, {"dataset_id": "d1", "dialog_id": "dialog-1", "name": "run 1"})
    monkeypatch.setattr(module.EvaluationService, "start_evaluation", lambda **_kwargs: (False, "start failed"))
    res = _run(module.start_evaluation())
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "start failed" in res["message"]
    # start_evaluation: success returns the new run id.
    monkeypatch.setattr(module.EvaluationService, "start_evaluation", lambda **_kwargs: (True, "run-ok"))
    res = _run(module.start_evaluation())
    assert res["code"] == 0
    assert res["data"]["run_id"] == "run-ok"
    # start_evaluation: a service exception surfaces as EXCEPTION_ERROR.
    def _raise_start(**_kwargs):
        raise RuntimeError("start boom")
    monkeypatch.setattr(module.EvaluationService, "start_evaluation", _raise_start)
    res = _run(module.start_evaluation())
    assert res["code"] == module.RetCode.EXCEPTION_ERROR
    assert "start boom" in res["message"]
    # get_evaluation_run: missing run -> DATA_ERROR "not found".
    monkeypatch.setattr(module.EvaluationService, "get_run_results", lambda _run_id: None)
    res = _run(module.get_evaluation_run("run-1"))
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "not found" in res["message"].lower()
    # get_evaluation_run: success echoes the run payload.
    monkeypatch.setattr(module.EvaluationService, "get_run_results", lambda _run_id: {"id": _run_id})
    res = _run(module.get_evaluation_run("run-2"))
    assert res["code"] == 0
    assert res["data"]["id"] == "run-2"
    # get_evaluation_run: a service exception surfaces as EXCEPTION_ERROR.
    def _raise_get_run(_run_id):
        raise RuntimeError("get run boom")
    monkeypatch.setattr(module.EvaluationService, "get_run_results", _raise_get_run)
    res = _run(module.get_evaluation_run("run-3"))
    assert res["code"] == module.RetCode.EXCEPTION_ERROR
    assert "get run boom" in res["message"]
    # get_run_results: same not-found / success / exception triplet.
    monkeypatch.setattr(module.EvaluationService, "get_run_results", lambda _run_id: None)
    res = _run(module.get_run_results("run-4"))
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "not found" in res["message"].lower()
    monkeypatch.setattr(module.EvaluationService, "get_run_results", lambda _run_id: {"id": _run_id, "score": 0.9})
    res = _run(module.get_run_results("run-5"))
    assert res["code"] == 0
    assert res["data"]["id"] == "run-5"
    def _raise_results(_run_id):
        raise RuntimeError("get results boom")
    monkeypatch.setattr(module.EvaluationService, "get_run_results", _raise_results)
    res = _run(module.get_run_results("run-6"))
    assert res["code"] == module.RetCode.EXCEPTION_ERROR
    assert "get results boom" in res["message"]
    # list_evaluation_runs: default stub yields an empty listing.
    res = _run(module.list_evaluation_runs())
    assert res["code"] == 0
    assert res["data"]["total"] == 0
    # list_evaluation_runs: failures inside response building are wrapped too.
    def _raise_json_list(*_args, **_kwargs):
        raise RuntimeError("list runs boom")
    monkeypatch.setattr(module, "get_json_result", _raise_json_list)
    res = _run(module.list_evaluation_runs())
    assert res["code"] == module.RetCode.EXCEPTION_ERROR
    assert "list runs boom" in res["message"]
    # Restore the JSON helper before exercising delete_evaluation_run.
    monkeypatch.setattr(module, "get_json_result", lambda code=0, message="success", data=None: {"code": code, "message": message, "data": data})
    res = _run(module.delete_evaluation_run("run-7"))
    assert res["code"] == 0
    assert res["data"]["run_id"] == "run-7"
    # delete_evaluation_run: response-building failure is wrapped as EXCEPTION_ERROR.
    def _raise_json_delete(*_args, **_kwargs):
        raise RuntimeError("delete run boom")
    monkeypatch.setattr(module, "get_json_result", _raise_json_delete)
    res = _run(module.delete_evaluation_run("run-8"))
    assert res["code"] == module.RetCode.EXCEPTION_ERROR
    assert "delete run boom" in res["message"]
    # get_recommendations: restore helper; success returns the service payload.
    monkeypatch.setattr(module, "get_json_result", lambda code=0, message="success", data=None: {"code": code, "message": message, "data": data})
    monkeypatch.setattr(module.EvaluationService, "get_recommendations", lambda _run_id: [{"name": "cfg-1"}])
    res = _run(module.get_recommendations("run-9"))
    assert res["code"] == 0
    assert res["data"]["recommendations"][0]["name"] == "cfg-1"
    # get_recommendations: a service exception surfaces as EXCEPTION_ERROR.
    def _raise_recommend(_run_id):
        raise RuntimeError("recommend boom")
    monkeypatch.setattr(module.EvaluationService, "get_recommendations", _raise_recommend)
    res = _run(module.get_recommendations("run-10"))
    assert res["code"] == module.RetCode.EXCEPTION_ERROR
    assert "recommend boom" in res["message"]
@pytest.mark.p2
def test_compare_export_and_evaluate_single_matrix_unit(monkeypatch):
    """Matrix test for compare_runs, export_results and evaluate_single routes."""
    module = _load_evaluation_app(monkeypatch)
    # compare_runs: fewer than two run ids is rejected.
    _set_request_json(monkeypatch, module, {"run_ids": ["run-1"]})
    res = _run(module.compare_runs())
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "at least 2" in res["message"]
    # compare_runs: with the stub service the comparison payload is empty.
    _set_request_json(monkeypatch, module, {"run_ids": ["run-1", "run-2"]})
    res = _run(module.compare_runs())
    assert res["code"] == 0
    assert res["data"]["comparison"] == {}
    # compare_runs: failures while building the response are wrapped.
    def _raise_json_compare(*_args, **_kwargs):
        raise RuntimeError("compare boom")
    monkeypatch.setattr(module, "get_json_result", _raise_json_compare)
    _set_request_json(monkeypatch, module, {"run_ids": ["run-1", "run-2", "run-3"]})
    res = _run(module.compare_runs())
    assert res["code"] == module.RetCode.EXCEPTION_ERROR
    assert "compare boom" in res["message"]
    # export_results: restore the JSON helper, then cover not-found / success / exception.
    monkeypatch.setattr(module, "get_json_result", lambda code=0, message="success", data=None: {"code": code, "message": message, "data": data})
    monkeypatch.setattr(module.EvaluationService, "get_run_results", lambda _run_id: None)
    res = _run(module.export_results("run-11"))
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "not found" in res["message"].lower()
    monkeypatch.setattr(module.EvaluationService, "get_run_results", lambda _run_id: {"id": _run_id, "rows": []})
    res = _run(module.export_results("run-12"))
    assert res["code"] == 0
    assert res["data"]["id"] == "run-12"
    def _raise_export(_run_id):
        raise RuntimeError("export boom")
    monkeypatch.setattr(module.EvaluationService, "get_run_results", _raise_export)
    res = _run(module.export_results("run-13"))
    assert res["code"] == module.RetCode.EXCEPTION_ERROR
    assert "export boom" in res["message"]
    # evaluate_single: defaults to an empty answer/metrics/chunks payload.
    monkeypatch.setattr(module, "get_json_result", lambda code=0, message="success", data=None: {"code": code, "message": message, "data": data})
    res = _run(module.evaluate_single())
    assert res["code"] == 0
    assert res["data"]["answer"] == ""
    assert res["data"]["metrics"] == {}
    assert res["data"]["retrieved_chunks"] == []
    # evaluate_single: response-building failures are wrapped as EXCEPTION_ERROR.
    def _raise_json_single(*_args, **_kwargs):
        raise RuntimeError("single boom")
    monkeypatch.setattr(module, "get_json_result", _raise_json_single)
    res = _run(module.evaluate_single())
    assert res["code"] == module.RetCode.EXCEPTION_ERROR
    assert "single boom" in res["message"]

View File

@ -0,0 +1,367 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import functools
import importlib.util
import sys
from copy import deepcopy
from enum import Enum
from pathlib import Path
from types import ModuleType, SimpleNamespace
import pytest
class _DummyManager:
def route(self, *_args, **_kwargs):
def decorator(func):
return func
return decorator
class _AwaitableValue:
def __init__(self, value):
self._value = value
def __await__(self):
async def _co():
return self._value
return _co().__await__()
class _DummyFile:
def __init__(self, file_id, file_type, *, name="file.txt", location="loc", size=1):
self.id = file_id
self.type = file_type
self.name = name
self.location = location
self.size = size
class _FalsyFile(_DummyFile):
    """A ``_DummyFile`` that evaluates as falsy, to hit ``if not file`` branches."""

    def __bool__(self):
        # Always falsy, regardless of the attributes carried by the instance.
        return False
def _run(coro):
return asyncio.run(coro)
def _set_request_json(monkeypatch, module, payload_state):
async def _req_json():
return deepcopy(payload_state)
monkeypatch.setattr(module, "get_request_json", _req_json)
@pytest.fixture(scope="session")
def auth():
    # Dummy auth token; these unit tests never hit the network, so any value works.
    return "unit-auth"
@pytest.fixture(scope="session", autouse=True)
def set_tenant_info():
    # Autouse no-op; presumably overrides a heavier conftest fixture of the same
    # name so these unit tests skip real tenant setup — TODO confirm against conftest.
    return None
def _load_file2document_module(monkeypatch):
    """Import ``api/apps/file2document_app.py`` against an all-stub dependency graph.

    Every package the app imports (``api``, ``api.apps``, the ``api.db``
    service modules, ``api.utils.api_utils`` and the ``common`` helpers) is
    replaced with an in-memory module registered via
    ``monkeypatch.setitem(sys.modules, ...)`` so the route handlers can be
    exercised without a database or web server. Returns the freshly executed
    module object; ``monkeypatch`` undoes the sys.modules edits afterwards.
    """
    repo_root = Path(__file__).resolve().parents[4]
    # Synthetic "api" package rooted at the real source tree so file paths resolve.
    api_pkg = ModuleType("api")
    api_pkg.__path__ = [str(repo_root / "api")]
    monkeypatch.setitem(sys.modules, "api", api_pkg)
    apps_mod = ModuleType("api.apps")
    apps_mod.__path__ = [str(repo_root / "api" / "apps")]
    apps_mod.current_user = SimpleNamespace(id="user-1")
    # Authentication becomes a pass-through decorator.
    apps_mod.login_required = lambda func: func
    monkeypatch.setitem(sys.modules, "api.apps", apps_mod)
    api_pkg.apps = apps_mod
    db_pkg = ModuleType("api.db")
    db_pkg.__path__ = []
    class _FileType(Enum):
        FOLDER = "folder"
        DOC = "doc"
    db_pkg.FileType = _FileType
    monkeypatch.setitem(sys.modules, "api.db", db_pkg)
    api_pkg.db = db_pkg
    services_pkg = ModuleType("api.db.services")
    services_pkg.__path__ = []
    monkeypatch.setitem(sys.modules, "api.db.services", services_pkg)
    # Stub service layer: each class exposes only the static methods the routes call.
    file2document_mod = ModuleType("api.db.services.file2document_service")
    class _StubFile2DocumentService:
        @staticmethod
        def get_by_file_id(_file_id):
            return []
        @staticmethod
        def delete_by_file_id(*_args, **_kwargs):
            return None
        @staticmethod
        def insert(_payload):
            return SimpleNamespace(to_json=lambda: {})
    file2document_mod.File2DocumentService = _StubFile2DocumentService
    monkeypatch.setitem(sys.modules, "api.db.services.file2document_service", file2document_mod)
    services_pkg.file2document_service = file2document_mod
    file_service_mod = ModuleType("api.db.services.file_service")
    class _StubFileService:
        @staticmethod
        def get_by_ids(_file_ids):
            return []
        @staticmethod
        def get_all_innermost_file_ids(_file_id, _acc):
            return []
        @staticmethod
        def get_by_id(_file_id):
            return True, _DummyFile(_file_id, _FileType.DOC.value)
    file_service_mod.FileService = _StubFileService
    monkeypatch.setitem(sys.modules, "api.db.services.file_service", file_service_mod)
    services_pkg.file_service = file_service_mod
    kb_service_mod = ModuleType("api.db.services.knowledgebase_service")
    class _StubKnowledgebaseService:
        @staticmethod
        def get_by_id(_kb_id):
            return False, None
    kb_service_mod.KnowledgebaseService = _StubKnowledgebaseService
    monkeypatch.setitem(sys.modules, "api.db.services.knowledgebase_service", kb_service_mod)
    services_pkg.knowledgebase_service = kb_service_mod
    document_service_mod = ModuleType("api.db.services.document_service")
    class _StubDocumentService:
        @staticmethod
        def get_by_id(doc_id):
            return True, SimpleNamespace(id=doc_id)
        @staticmethod
        def get_tenant_id(_doc_id):
            return "tenant-1"
        @staticmethod
        def remove_document(*_args, **_kwargs):
            return True
        @staticmethod
        def insert(_payload):
            return SimpleNamespace(id="doc-1")
    document_service_mod.DocumentService = _StubDocumentService
    monkeypatch.setitem(sys.modules, "api.db.services.document_service", document_service_mod)
    services_pkg.document_service = document_service_mod
    # Minimal api_utils: JSON helpers return plain dicts instead of HTTP responses,
    # which keeps route assertions simple (res["code"], res["message"], ...).
    api_utils_mod = ModuleType("api.utils.api_utils")
    def get_json_result(data=None, message="", code=0):
        return {"code": code, "data": data, "message": message}
    def get_data_error_result(message=""):
        return {"code": 102, "data": None, "message": message}
    async def get_request_json():
        return {}
    def server_error_response(err):
        return {"code": 500, "data": None, "message": str(err)}
    def validate_request(*_keys):
        # No-op validator: simply awaits the wrapped async handler.
        def _decorator(func):
            @functools.wraps(func)
            async def _wrapper(*args, **kwargs):
                return await func(*args, **kwargs)
            return _wrapper
        return _decorator
    api_utils_mod.get_json_result = get_json_result
    api_utils_mod.get_data_error_result = get_data_error_result
    api_utils_mod.get_request_json = get_request_json
    api_utils_mod.server_error_response = server_error_response
    api_utils_mod.validate_request = validate_request
    monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod)
    misc_utils_mod = ModuleType("common.misc_utils")
    misc_utils_mod.get_uuid = lambda: "uuid"
    monkeypatch.setitem(sys.modules, "common.misc_utils", misc_utils_mod)
    constants_mod = ModuleType("common.constants")
    class _RetCode:
        ARGUMENT_ERROR = 101
    constants_mod.RetCode = _RetCode
    monkeypatch.setitem(sys.modules, "common.constants", constants_mod)
    # Execute the real route module under a unique name with the stub graph in place.
    module_name = "test_file2document_routes_unit_module"
    module_path = repo_root / "api" / "apps" / "file2document_app.py"
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    module = importlib.util.module_from_spec(spec)
    # Inject the dummy blueprint manager before exec so @manager.route is a no-op.
    module.manager = _DummyManager()
    monkeypatch.setitem(sys.modules, module_name, module)
    spec.loader.exec_module(module)
    return module
@pytest.mark.p2
def test_convert_branch_matrix_unit(monkeypatch):
    """Walk every early-return branch of the file2document ``convert`` route."""
    module = _load_file2document_module(monkeypatch)
    req_state = {"kb_ids": ["kb-1"], "file_ids": ["f1"]}
    _set_request_json(monkeypatch, module, req_state)
    events = {"deleted": []}
    # Falsy file object -> "File not found!".
    monkeypatch.setattr(module.FileService, "get_by_ids", lambda _ids: [_FalsyFile("f1", module.FileType.DOC.value)])
    res = _run(module.convert())
    assert res["message"] == "File not found!"
    # Linked document missing -> "Document not found!".
    monkeypatch.setattr(module.FileService, "get_by_ids", lambda _ids: [_DummyFile("f1", module.FileType.DOC.value)])
    monkeypatch.setattr(module.File2DocumentService, "get_by_file_id", lambda _file_id: [SimpleNamespace(document_id="doc-1")])
    monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (False, None))
    res = _run(module.convert())
    assert res["message"] == "Document not found!"
    # Tenant lookup failing -> "Tenant not found!".
    monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (True, SimpleNamespace(id=_doc_id)))
    monkeypatch.setattr(module.DocumentService, "get_tenant_id", lambda _doc_id: None)
    res = _run(module.convert())
    assert res["message"] == "Tenant not found!"
    # Document removal refused by the service.
    monkeypatch.setattr(module.DocumentService, "get_tenant_id", lambda _doc_id: "tenant-1")
    monkeypatch.setattr(module.DocumentService, "remove_document", lambda *_args, **_kwargs: False)
    res = _run(module.convert())
    assert "Document removal" in res["message"]
    # Unknown dataset: old links are still purged before the error is returned.
    monkeypatch.setattr(module.DocumentService, "remove_document", lambda *_args, **_kwargs: True)
    monkeypatch.setattr(module.File2DocumentService, "get_by_file_id", lambda _file_id: [])
    monkeypatch.setattr(module.File2DocumentService, "delete_by_file_id", lambda file_id: events["deleted"].append(file_id))
    monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (False, None))
    res = _run(module.convert())
    assert res["message"] == "Can't find this dataset!"
    assert events["deleted"] == ["f1"]
    # Dataset found but file lookup fails.
    kb = SimpleNamespace(id="kb-1", parser_id="naive", pipeline_id="p1", parser_config={})
    monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, kb))
    monkeypatch.setattr(module.FileService, "get_by_id", lambda _file_id: (False, None))
    res = _run(module.convert())
    assert res["message"] == "Can't find this file!"
    # Folder input: innermost files are expanded and converted successfully.
    req_state["file_ids"] = ["folder-1"]
    monkeypatch.setattr(module.FileService, "get_by_ids", lambda _ids: [_DummyFile("folder-1", module.FileType.FOLDER.value, name="folder")])
    monkeypatch.setattr(module.FileService, "get_all_innermost_file_ids", lambda _file_id, _acc: ["inner-1"])
    monkeypatch.setattr(
        module.FileService,
        "get_by_id",
        lambda _file_id: (True, _DummyFile("inner-1", module.FileType.DOC.value, name="inner.txt", location="inner.loc", size=2)),
    )
    monkeypatch.setattr(module.DocumentService, "insert", lambda _payload: SimpleNamespace(id="doc-new"))
    monkeypatch.setattr(
        module.File2DocumentService,
        "insert",
        lambda _payload: SimpleNamespace(to_json=lambda: {"file_id": "inner-1", "document_id": "doc-new"}),
    )
    res = _run(module.convert())
    assert res["code"] == 0
    assert res["data"] == [{"file_id": "inner-1", "document_id": "doc-new"}]
    # Unexpected exceptions become a 500 server-error payload.
    req_state["file_ids"] = ["f1"]
    monkeypatch.setattr(
        module.FileService,
        "get_by_ids",
        lambda _ids: (_ for _ in ()).throw(RuntimeError("convert boom")),
    )
    res = _run(module.convert())
    assert res["code"] == 500
    assert "convert boom" in res["message"]
@pytest.mark.p2
def test_rm_branch_matrix_unit(monkeypatch):
    """Walk every branch of the file2document ``rm`` route."""
    module = _load_file2document_module(monkeypatch)
    req_state = {"file_ids": []}
    _set_request_json(monkeypatch, module, req_state)
    deleted = []
    # Empty file_ids list is an argument error.
    res = _run(module.rm())
    assert res["code"] == module.RetCode.ARGUMENT_ERROR
    assert 'Lack of "Files ID"' in res["message"]
    # No link rows (empty list, or a falsy first row) -> "Inform not found!".
    req_state["file_ids"] = ["f1"]
    monkeypatch.setattr(module.File2DocumentService, "get_by_file_id", lambda _file_id: [])
    res = _run(module.rm())
    assert res["message"] == "Inform not found!"
    monkeypatch.setattr(module.File2DocumentService, "get_by_file_id", lambda _file_id: [None])
    res = _run(module.rm())
    assert res["message"] == "Inform not found!"
    # Document missing: the link row is deleted first, then the error is reported.
    monkeypatch.setattr(module.File2DocumentService, "get_by_file_id", lambda _file_id: [SimpleNamespace(document_id="doc-1")])
    monkeypatch.setattr(module.File2DocumentService, "delete_by_file_id", lambda file_id: deleted.append(file_id))
    monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (False, None))
    res = _run(module.rm())
    assert res["message"] == "Document not found!"
    assert deleted == ["f1"]
    # Tenant lookup failing.
    monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (True, SimpleNamespace(id=_doc_id)))
    monkeypatch.setattr(module.DocumentService, "get_tenant_id", lambda _doc_id: None)
    res = _run(module.rm())
    assert res["message"] == "Tenant not found!"
    # Removal refused by the document service.
    monkeypatch.setattr(module.DocumentService, "get_tenant_id", lambda _doc_id: "tenant-1")
    monkeypatch.setattr(module.DocumentService, "remove_document", lambda *_args, **_kwargs: False)
    res = _run(module.rm())
    assert "Document removal" in res["message"]
    # Happy path across multiple files returns data=True.
    req_state["file_ids"] = ["f1", "f2"]
    monkeypatch.setattr(
        module.File2DocumentService,
        "get_by_file_id",
        lambda file_id: [SimpleNamespace(document_id=f"doc-{file_id}")],
    )
    monkeypatch.setattr(module.DocumentService, "get_by_id", lambda doc_id: (True, SimpleNamespace(id=doc_id)))
    monkeypatch.setattr(module.DocumentService, "get_tenant_id", lambda _doc_id: "tenant-1")
    monkeypatch.setattr(module.DocumentService, "remove_document", lambda *_args, **_kwargs: True)
    res = _run(module.rm())
    assert res["code"] == 0
    assert res["data"] is True
    # Unexpected exceptions become a 500 server-error payload.
    monkeypatch.setattr(
        module.File2DocumentService,
        "get_by_file_id",
        lambda _file_id: (_ for _ in ()).throw(RuntimeError("rm boom")),
    )
    req_state["file_ids"] = ["boom"]
    res = _run(module.rm())
    assert res["code"] == 500
    assert "rm boom" in res["message"]

File diff suppressed because it is too large Load Diff

View File

@ -26,6 +26,8 @@ from types import ModuleType, SimpleNamespace
import pytest
pytestmark = pytest.mark.filterwarnings("ignore:.*joblib will operate in serial mode.*:UserWarning")
class _DummyManager:
def route(self, *_args, **_kwargs):
@ -169,6 +171,16 @@ def _base_update_payload(**kwargs):
return payload
@pytest.fixture(scope="session")
def auth():
return "unit-auth"
@pytest.fixture(scope="session", autouse=True)
def set_tenant_info():
return None
@pytest.mark.p2
def test_create_branches(monkeypatch):
module = _load_kb_module(monkeypatch)
@ -1046,3 +1058,236 @@ def test_unbind_task_branch_matrix(monkeypatch):
res = route()
assert res["code"] == module.RetCode.EXCEPTION_ERROR, res
assert "cannot delete task" in res["message"], res
@pytest.mark.p2
def test_check_embedding_similarity_threshold_matrix_unit(monkeypatch):
    """Drive ``check_embedding`` through its per-chunk skip reasons and thresholds.

    The sampled chunks are crafted so each one exercises a different branch:
    missing stored vector, wrongly-typed vector, a zero vector with low
    similarity, text that normalises to empty, and a title+content mix.
    A second pass with a single high-similarity chunk checks the success path.
    """
    module = _load_kb_module(monkeypatch)
    route = inspect.unwrap(module.check_embedding)
    monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, SimpleNamespace(tenant_id="tenant-1")))
    monkeypatch.setattr(module.search, "index_name", lambda _tenant_id: "idx")
    class _FlipBool:
        # Truthy on the first bool() check only, falsy afterwards — used to make
        # the route's post-re.sub truthiness test take the "no_text" branch.
        def __init__(self):
            self._calls = 0
        def __bool__(self):
            self._calls += 1
            return self._calls == 1
    # re.sub is hijacked so only the TRIGGER_NO_TEXT chunk yields a _FlipBool.
    monkeypatch.setattr(
        module.re,
        "sub",
        lambda _pattern, _repl, text: _FlipBool() if "TRIGGER_NO_TEXT" in str(text) else text,
    )
    def _fixed_sample(population, k):
        # Deterministic sampling: always take the first k candidates.
        return list(population)[:k]
    monkeypatch.setattr(module.random, "sample", _fixed_sample)
    class _DocStore:
        # In-memory stand-in for settings.docStoreConn with scripted results:
        # an empty select yields the "total" probe, otherwise a per-offset sample.
        def __init__(self, total, ids_by_offset, docs):
            self.total = total
            self.ids_by_offset = ids_by_offset
            self.docs = docs
        def search(self, select_fields, **kwargs):
            if not select_fields:
                return {"kind": "total"}
            return {"kind": "sample", "offset": kwargs["offset"]}
        def get_total(self, _res):
            return self.total
        def get_doc_ids(self, res):
            return self.ids_by_offset.get(res.get("offset", -1), [])
        def get(self, cid, _index_name, _kb_ids):
            return self.docs.get(cid, {})
    class _EmbModel:
        # Scripted embedding model: similarity outcome is keyed off the doc title.
        def __init__(self):
            self.calls = []
        def encode(self, pair):
            title, _txt = pair
            self.calls.append(title)
            if title == "Doc Mix":
                # title+content mix wins over content only path.
                return [module.np.array([1.0, 0.0]), module.np.array([0.0, 1.0])], None
            if title == "Doc High":
                return [module.np.array([1.0, 0.0]), module.np.array([1.0, 0.0])], None
            return [module.np.array([0.0, 1.0]), module.np.array([0.0, 1.0])], None
    emb_model = _EmbModel()
    monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: emb_model)
    # One chunk per skip/score branch; offset 0 is deliberately empty.
    low_docs = {
        "chunk-no-vec": {
            "doc_id": "doc-no-vec",
            "docnm_kwd": "Doc No Vec",
            "content_with_weight": "body-no-vec",
            "page_num_int": 1,
            "position_int": 1,
            "top_int": 1,
        },
        "chunk-bad-type": {
            "doc_id": "doc-bad-type",
            "docnm_kwd": "Doc Bad Type",
            "content_with_weight": "body-bad-type",
            "question_kwd": [],
            "q_vec": {"bad": "type"},
            "page_num_int": 1,
            "position_int": 2,
            "top_int": 2,
        },
        "chunk-low-zero": {
            "doc_id": "doc-low-zero",
            "docnm_kwd": "Doc Low Zero",
            "content_with_weight": "body-low",
            "question_kwd": [],
            "q_vec": "0\t0",
            "page_num_int": 1,
            "position_int": 3,
            "top_int": 3,
        },
        "chunk-no-text": {
            "doc_id": "doc-no-text",
            "docnm_kwd": "Doc No Text",
            "content_with_weight": "TRIGGER_NO_TEXT",
            "q_vec": [1.0, 0.0],
            "page_num_int": 1,
            "position_int": 4,
            "top_int": 4,
        },
        "chunk-mix": {
            "doc_id": "doc-mix",
            "docnm_kwd": "Doc Mix",
            "content_with_weight": "body-mix",
            "q_vec": [1.0, 0.0],
            "page_num_int": 1,
            "position_int": 5,
            "top_int": 5,
        },
    }
    monkeypatch.setattr(
        module.settings,
        "docStoreConn",
        _DocStore(
            total=6,
            ids_by_offset={
                0: [],
                1: ["chunk-no-vec"],
                2: ["chunk-bad-type"],
                3: ["chunk-low-zero"],
                4: ["chunk-no-text"],
                5: ["chunk-mix"],
            },
            docs=low_docs,
        ),
    )
    # Low-similarity pass: expect NOT_EFFECTIVE plus per-chunk skip reasons.
    _set_request_json(monkeypatch, module, {"kb_id": "kb-1", "embd_id": "emb-1", "check_num": 6})
    res = _run(route())
    assert res["code"] == module.RetCode.NOT_EFFECTIVE, res
    assert "average similarity" in res["message"], res
    summary = res["data"]["summary"]
    assert summary["sampled"] == 5, summary
    assert summary["valid"] == 2, summary
    reasons = {item.get("reason") for item in res["data"]["results"] if "reason" in item}
    assert "no_stored_vector" in reasons, res
    assert "no_text" in reasons, res
    assert any(item.get("chunk_id") == "chunk-low-zero" and "cos_sim" in item for item in res["data"]["results"]), res
    assert summary["match_mode"] in {"content_only", "title+content"}, summary
    # High-similarity pass: a single well-matched chunk clears the threshold.
    high_docs = {
        "chunk-high": {
            "doc_id": "doc-high",
            "docnm_kwd": "Doc High",
            "content_with_weight": "body-high",
            "q_vec": [1.0, 0.0],
            "page_num_int": 1,
            "position_int": 1,
            "top_int": 1,
        }
    }
    monkeypatch.setattr(
        module.settings,
        "docStoreConn",
        _DocStore(total=1, ids_by_offset={0: ["chunk-high"]}, docs=high_docs),
    )
    _set_request_json(monkeypatch, module, {"kb_id": "kb-1", "embd_id": "emb-1", "check_num": 1})
    res = _run(route())
    assert res["code"] == module.RetCode.SUCCESS, res
    assert res["data"]["summary"]["avg_cos_sim"] > 0.9, res
@pytest.mark.p2
def test_check_embedding_error_and_empty_sample_paths_unit(monkeypatch):
    """Cover the embedding-failure and empty-sample paths of ``check_embedding``."""
    module = _load_kb_module(monkeypatch)
    route = inspect.unwrap(module.check_embedding)
    monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, SimpleNamespace(tenant_id="tenant-1")))
    monkeypatch.setattr(module.search, "index_name", lambda _tenant_id: "idx")
    monkeypatch.setattr(module.random, "sample", lambda population, k: list(population)[:k])
    class _DocStore:
        # Scripted stand-in for settings.docStoreConn (same shape as the
        # happy-path test's doc store).
        def __init__(self, total, ids_by_offset, docs):
            self.total = total
            self.ids_by_offset = ids_by_offset
            self.docs = docs
        def search(self, select_fields, **kwargs):
            if not select_fields:
                return {"kind": "total"}
            return {"kind": "sample", "offset": kwargs["offset"]}
        def get_total(self, _res):
            return self.total
        def get_doc_ids(self, res):
            return self.ids_by_offset.get(res.get("offset", -1), [])
        def get(self, cid, _index_name, _kb_ids):
            return self.docs.get(cid, {})
    class _BoomEmbModel:
        # Embedding model whose encode always raises, to hit the failure branch.
        def encode(self, _pair):
            raise RuntimeError("encode boom")
    monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: _BoomEmbModel())
    monkeypatch.setattr(
        module.settings,
        "docStoreConn",
        _DocStore(
            total=1,
            ids_by_offset={0: ["chunk-err"]},
            docs={
                "chunk-err": {
                    "doc_id": "doc-err",
                    "docnm_kwd": "Doc Err",
                    "content_with_weight": "body-err",
                    "q_vec": [1.0, 0.0],
                    "page_num_int": 1,
                    "position_int": 1,
                    "top_int": 1,
                }
            },
        ),
    )
    # Encoding failure -> DATA_ERROR with both the route's and the model's message.
    _set_request_json(monkeypatch, module, {"kb_id": "kb-1", "embd_id": "emb-1", "check_num": 1})
    res = _run(route())
    assert res["code"] == module.RetCode.DATA_ERROR, res
    assert "Embedding failure." in res["message"], res
    assert "encode boom" in res["message"], res
    class _OkEmbModel:
        def encode(self, _pair):
            return [module.np.array([1.0, 0.0]), module.np.array([1.0, 0.0])], None
    monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: _OkEmbModel())
    monkeypatch.setattr(module.settings, "docStoreConn", _DocStore(total=0, ids_by_offset={}, docs={}))
    _set_request_json(monkeypatch, module, {"kb_id": "kb-1", "embd_id": "emb-1", "check_num": 1})
    # With zero sampled chunks the route currently raises UnboundLocalError;
    # this pins the (arguably buggy) present behavior so a fix will show up here.
    with pytest.raises(UnboundLocalError):
        _run(route())

View File

@ -16,6 +16,7 @@
import asyncio
import importlib.util
import json
import sys
from pathlib import Path
from types import ModuleType, SimpleNamespace
@ -39,18 +40,39 @@ class _ExprField:
return (self.name, other)
class _StrEnum(str):
@property
def value(self):
return str(self)
class _DummyTenantLLMModel:
tenant_id = _ExprField("tenant_id")
llm_factory = _ExprField("llm_factory")
llm_name = _ExprField("llm_name")
class _TenantLLMRow:
def __init__(self, *, llm_name, llm_factory, model_type, api_key="key", status="1"):
def __init__(
self,
*,
llm_name,
llm_factory,
model_type,
api_key="key",
status="1",
used_tokens=0,
api_base="",
max_tokens=8192,
):
self.llm_name = llm_name
self.llm_factory = llm_factory
self.model_type = model_type
self.api_key = api_key
self.status = status
self.used_tokens = used_tokens
self.api_base = api_base
self.max_tokens = max_tokens
def to_dict(self):
return {
@ -58,15 +80,19 @@ class _TenantLLMRow:
"llm_factory": self.llm_factory,
"model_type": self.model_type,
"status": self.status,
"used_tokens": self.used_tokens,
"api_base": self.api_base,
"max_tokens": self.max_tokens,
}
class _LLMRow:
def __init__(self, *, llm_name, fid, model_type, status="1"):
def __init__(self, *, llm_name, fid, model_type, status="1", max_tokens=2048):
self.llm_name = llm_name
self.fid = fid
self.model_type = model_type
self.status = status
self.max_tokens = max_tokens
def to_dict(self):
return {
@ -74,6 +100,7 @@ class _LLMRow:
"fid": self.fid,
"model_type": self.model_type,
"status": self.status,
"max_tokens": self.max_tokens,
}
@ -81,6 +108,13 @@ def _run(coro):
return asyncio.run(coro)
def _set_request_json(monkeypatch, module, payload):
async def _get_request_json():
return dict(payload)
monkeypatch.setattr(module, "get_request_json", _get_request_json)
def _load_llm_app(monkeypatch):
repo_root = Path(__file__).resolve().parents[4]
@ -122,6 +156,10 @@ def _load_llm_app(monkeypatch):
def filter_delete(_filters):
return True
@staticmethod
def filter_update(_filters, _payload):
return True
tenant_llm_mod.LLMFactoriesService = _StubLLMFactoriesService
tenant_llm_mod.TenantLLMService = _StubTenantLLMService
monkeypatch.setitem(sys.modules, "api.db.services.tenant_llm_service", tenant_llm_mod)
@ -164,13 +202,13 @@ def _load_llm_app(monkeypatch):
constants_mod = ModuleType("common.constants")
constants_mod.StatusEnum = SimpleNamespace(VALID=SimpleNamespace(value="1"), INVALID=SimpleNamespace(value="0"))
constants_mod.LLMType = SimpleNamespace(
CHAT="chat",
EMBEDDING="embedding",
SPEECH2TEXT="speech2text",
IMAGE2TEXT="image2text",
RERANK="rerank",
TTS="tts",
OCR="ocr",
CHAT=_StrEnum("chat"),
EMBEDDING=_StrEnum("embedding"),
SPEECH2TEXT=_StrEnum("speech2text"),
IMAGE2TEXT=_StrEnum("image2text"),
RERANK=_StrEnum("rerank"),
TTS=_StrEnum("tts"),
OCR=_StrEnum("ocr"),
)
monkeypatch.setitem(sys.modules, "common.constants", constants_mod)
@ -179,7 +217,7 @@ def _load_llm_app(monkeypatch):
monkeypatch.setitem(sys.modules, "api.db.db_models", db_models_mod)
base64_mod = ModuleType("rag.utils.base64_image")
base64_mod.test_image = lambda _s: _s
base64_mod.test_image = b"image-bytes"
monkeypatch.setitem(sys.modules, "rag.utils.base64_image", base64_mod)
rag_llm_mod = ModuleType("rag.llm")
@ -288,3 +326,529 @@ def test_list_app_exception_path(monkeypatch):
res = _run(module.list_app())
assert res["code"] == 500
assert "query boom" in res["message"]
@pytest.mark.p2
def test_factories_route_success_and_exception_unit(monkeypatch):
    """``factories`` filters FastEmbed/Builtin vendors and aggregates model types."""
    module = _load_llm_app(monkeypatch)
    def _factory(name):
        # Minimal factory record exposing only name and to_dict, as the route uses.
        return SimpleNamespace(name=name, to_dict=lambda n=name: {"name": n})
    monkeypatch.setattr(
        module,
        "get_allowed_llm_factories",
        lambda: [
            _factory("OpenAI"),
            _factory("CustomFactory"),
            _factory("FastEmbed"),
            _factory("Builtin"),
        ],
    )
    # Only valid (status "1") models contribute to a factory's model_types.
    monkeypatch.setattr(
        module.LLMService,
        "get_all",
        lambda: [
            _LLMRow(llm_name="m1", fid="OpenAI", model_type="chat", status="1"),
            _LLMRow(llm_name="m2", fid="OpenAI", model_type="embedding", status="1"),
            _LLMRow(llm_name="m3", fid="OpenAI", model_type="rerank", status="0"),
        ],
    )
    res = module.factories()
    assert res["code"] == 0
    names = [item["name"] for item in res["data"]]
    # FastEmbed and Builtin are excluded from the user-facing factory list.
    assert "FastEmbed" not in names
    assert "Builtin" not in names
    assert {"OpenAI", "CustomFactory"} == set(names)
    openai = next(item for item in res["data"] if item["name"] == "OpenAI")
    assert {"chat", "embedding"} == set(openai["model_types"])
    # A lookup exception is converted into a 500 server-error payload.
    monkeypatch.setattr(module, "get_allowed_llm_factories", lambda: (_ for _ in ()).throw(RuntimeError("factories boom")))
    res = module.factories()
    assert res["code"] == 500
    assert "factories boom" in res["message"]
@pytest.mark.p2
def test_set_api_key_model_probe_matrix_unit(monkeypatch):
    """Cover /set_api_key probing: failing embedding/chat/rerank probes with
    verify=True vs verify=False, then the success path that persists credentials."""
    module = _load_llm_app(monkeypatch)

    # Collapse the async plumbing so probes run inline and deterministically.
    async def _wait_for(coro, *_args, **_kwargs):
        return await coro

    async def _to_thread(fn, *args, **kwargs):
        return fn(*args, **kwargs)

    monkeypatch.setattr(module.asyncio, "wait_for", _wait_for)
    monkeypatch.setattr(module.asyncio, "to_thread", _to_thread)

    class _EmbeddingFail:
        def __init__(self, *_args, **_kwargs):
            pass

        def encode(self, _texts):
            # Empty vector signals a failed embedding probe to the route.
            return [[]], 1

    class _EmbeddingPass:
        def __init__(self, *_args, **_kwargs):
            pass

        def encode(self, _texts):
            return [[0.1]], 1

    class _ChatFail:
        def __init__(self, *_args, **_kwargs):
            pass

        async def async_chat(self, *_args, **_kwargs):
            # "**ERROR**" prefix is how the route detects a chat-probe failure.
            return "**ERROR** chat fail", 1

    class _RerankFail:
        def __init__(self, *_args, **_kwargs):
            pass

        def similarity(self, *_args, **_kwargs):
            # Empty similarity list signals a failed rerank probe.
            return [], 0

    factory = "FactoryA"
    monkeypatch.setattr(
        module.LLMService,
        "query",
        lambda **_kwargs: [
            _LLMRow(llm_name="emb", fid=factory, model_type=module.LLMType.EMBEDDING.value, max_tokens=321),
            _LLMRow(llm_name="chat", fid=factory, model_type=module.LLMType.CHAT.value, max_tokens=654),
            _LLMRow(llm_name="rerank", fid=factory, model_type=module.LLMType.RERANK.value, max_tokens=987),
        ],
    )
    monkeypatch.setattr(module, "EmbeddingModel", {factory: _EmbeddingFail})
    monkeypatch.setattr(module, "ChatModel", {factory: _ChatFail})
    monkeypatch.setattr(module, "RerankModel", {factory: _RerankFail})
    req = {"llm_factory": factory, "api_key": "k", "base_url": "http://x", "verify": True}
    _set_request_json(monkeypatch, module, req)
    res = _run(module.set_api_key())
    # verify=True: failures are reported in-band with code 0 and success=False.
    assert res["code"] == 0
    assert res["data"]["success"] is False
    assert "Fail to access embedding model(emb)" in res["data"]["message"]
    assert "Fail to access model(FactoryA/chat)" in res["data"]["message"]
    assert "Fail to access model(FactoryA/rerank)" in res["data"]["message"]
    req["verify"] = False
    _set_request_json(monkeypatch, module, req)
    res = _run(module.set_api_key())
    # verify=False: the same failures surface as a 400 error response.
    assert res["code"] == 400
    assert "Fail to access embedding model(emb)" in res["message"]
    calls = {"filter_update": [], "save": []}

    def _filter_update(filters, payload):
        calls["filter_update"].append((filters, dict(payload)))
        # Returning False forces the route to fall through to save().
        return False

    def _save(**kwargs):
        calls["save"].append(kwargs)
        return True

    monkeypatch.setattr(module, "EmbeddingModel", {factory: _EmbeddingPass})
    monkeypatch.setattr(module.LLMService, "query", lambda **_kwargs: [_LLMRow(llm_name="emb-pass", fid=factory, model_type=module.LLMType.EMBEDDING.value, max_tokens=2049)])
    monkeypatch.setattr(module.TenantLLMService, "filter_update", _filter_update)
    monkeypatch.setattr(module.TenantLLMService, "save", _save)
    success_req = {
        "llm_factory": factory,
        "api_key": "k2",
        "base_url": "http://y",
        "model_type": "chat",
        "llm_name": "manual-model",
    }
    _set_request_json(monkeypatch, module, success_req)
    res = _run(module.set_api_key())
    assert res["code"] == 0
    assert res["data"] is True
    # The explicit model_type/llm_name from the request drive the filter_update,
    # while the persisted row (save) carries the queried model's name/max_tokens.
    assert calls["filter_update"]
    assert calls["filter_update"][0][1]["model_type"] == "chat"
    assert calls["filter_update"][0][1]["llm_name"] == "manual-model"
    assert calls["filter_update"][0][1]["max_tokens"] == 2049
    assert calls["save"][0]["max_tokens"] == 2049
    assert calls["save"][0]["llm_name"] == "emb-pass"
@pytest.mark.p2
def test_add_llm_factory_specific_key_assembly_unit(monkeypatch):
    """Cover /add_llm per-factory credential assembly: each special-cased factory
    packs its vendor-specific fields into the stored api_key (JSON or raw), some
    factories mangle llm_name with a suffix, and Tencent Cloud delegates to
    set_api_key."""
    module = _load_llm_app(monkeypatch)

    # Collapse async plumbing so model probes run inline.
    async def _wait_for(coro, *_args, **_kwargs):
        return await coro

    async def _to_thread(fn, *args, **kwargs):
        return fn(*args, **kwargs)

    monkeypatch.setattr(module.asyncio, "wait_for", _wait_for)
    monkeypatch.setattr(module.asyncio, "to_thread", _to_thread)
    allowed = [
        "VolcEngine",
        "Tencent Cloud",
        "Bedrock",
        "LocalAI",
        "HuggingFace",
        "OpenAI-API-Compatible",
        "VLLM",
        "XunFei Spark",
        "BaiduYiyan",
        "Fish Audio",
        "Google Cloud",
        "Azure-OpenAI",
        "OpenRouter",
        "MinerU",
        "PaddleOCR",
    ]
    monkeypatch.setattr(module, "get_allowed_llm_factories", lambda: [SimpleNamespace(name=name) for name in allowed])
    captured = {"chat": [], "tts": [], "filter_payloads": []}

    class _ChatOK:
        def __init__(self, key, model_name, base_url="", **_kwargs):
            captured["chat"].append((key, model_name, base_url))

        async def async_chat(self, *_args, **_kwargs):
            return "ok", 1

    class _TTSOK:
        def __init__(self, key, model_name, base_url="", **_kwargs):
            captured["tts"].append((key, model_name, base_url))

        def tts(self, _text):
            yield b"ok"

    monkeypatch.setattr(module, "ChatModel", {name: _ChatOK for name in allowed})
    monkeypatch.setattr(module, "TTSModel", {"XunFei Spark": _TTSOK})
    # Record each persisted payload so _run_case can return the latest one.
    monkeypatch.setattr(module.TenantLLMService, "filter_update", lambda _filters, payload: captured["filter_payloads"].append(dict(payload)) or True)
    # A factory outside the allowed list is rejected up front.
    reject_req = {"llm_factory": "NotAllowed", "llm_name": "x", "model_type": module.LLMType.CHAT.value}
    _set_request_json(monkeypatch, module, reject_req)
    res = _run(module.add_llm())
    assert res["code"] == 400
    assert "is not allowed" in res["message"]

    def _run_case(factory, *, model_type=module.LLMType.CHAT.value, extra=None):
        # Drive add_llm for one factory and return the payload it persisted.
        req = {"llm_factory": factory, "llm_name": "model", "model_type": model_type, "api_key": "k", "api_base": "http://api"}
        if extra:
            req.update(extra)
        _set_request_json(monkeypatch, module, req)
        out = _run(module.add_llm())
        assert out["code"] == 0
        assert out["data"] is True
        return captured["filter_payloads"][-1]

    # Factories whose credentials are packed as a JSON blob in api_key:
    volc = _run_case("VolcEngine", extra={"ark_api_key": "ak", "endpoint_id": "eid"})
    assert json.loads(volc["api_key"]) == {"ark_api_key": "ak", "endpoint_id": "eid"}
    bedrock = _run_case(
        "Bedrock",
        extra={"auth_mode": "iam", "bedrock_ak": "ak", "bedrock_sk": "sk", "bedrock_region": "r", "aws_role_arn": "arn"},
    )
    assert json.loads(bedrock["api_key"]) == {
        "auth_mode": "iam",
        "bedrock_ak": "ak",
        "bedrock_sk": "sk",
        "bedrock_region": "r",
        "aws_role_arn": "arn",
    }
    # Factories that suffix the stored llm_name with a backend tag:
    localai = _run_case("LocalAI")
    assert localai["llm_name"] == "model___LocalAI"
    huggingface = _run_case("HuggingFace")
    assert huggingface["llm_name"] == "model___HuggingFace"
    openapi = _run_case("OpenAI-API-Compatible")
    assert openapi["llm_name"] == "model___OpenAI-API"
    vllm = _run_case("VLLM")
    assert vllm["llm_name"] == "model___VLLM"
    # XunFei Spark: chat stores the raw password, TTS packs a JSON triple.
    spark_chat = _run_case("XunFei Spark", extra={"spark_api_password": "spark-pass"})
    assert spark_chat["api_key"] == "spark-pass"
    spark_tts = _run_case(
        "XunFei Spark",
        model_type=module.LLMType.TTS.value,
        extra={"spark_app_id": "app", "spark_api_secret": "secret", "spark_api_key": "key"},
    )
    assert json.loads(spark_tts["api_key"]) == {
        "spark_app_id": "app",
        "spark_api_secret": "secret",
        "spark_api_key": "key",
    }
    baidu = _run_case("BaiduYiyan", extra={"yiyan_ak": "ak", "yiyan_sk": "sk"})
    assert json.loads(baidu["api_key"]) == {"yiyan_ak": "ak", "yiyan_sk": "sk"}
    fish = _run_case("Fish Audio", extra={"fish_audio_ak": "ak", "fish_audio_refid": "rid"})
    assert json.loads(fish["api_key"]) == {"fish_audio_ak": "ak", "fish_audio_refid": "rid"}
    google = _run_case(
        "Google Cloud",
        extra={"google_project_id": "pid", "google_region": "us", "google_service_account_key": "sak"},
    )
    assert json.loads(google["api_key"]) == {
        "google_project_id": "pid",
        "google_region": "us",
        "google_service_account_key": "sak",
    }
    azure = _run_case("Azure-OpenAI", extra={"api_key": "real-key", "api_version": "2024-01-01"})
    assert json.loads(azure["api_key"]) == {"api_key": "real-key", "api_version": "2024-01-01"}
    openrouter = _run_case("OpenRouter", extra={"api_key": "or-key", "provider_order": "a,b"})
    assert json.loads(openrouter["api_key"]) == {"api_key": "or-key", "provider_order": "a,b"}
    mineru = _run_case("MinerU", extra={"api_key": "m-key", "provider_order": "p1"})
    assert json.loads(mineru["api_key"]) == {"api_key": "m-key", "provider_order": "p1"}
    paddle = _run_case("PaddleOCR", extra={"api_key": "p-key", "provider_order": "p2"})
    assert json.loads(paddle["api_key"]) == {"api_key": "p-key", "provider_order": "p2"}
    # Tencent Cloud: add_llm packs the sid/sk pair into the request's api_key
    # and delegates the rest of the flow to set_api_key.
    tencent_req = {
        "llm_factory": "Tencent Cloud",
        "llm_name": "model",
        "model_type": module.LLMType.CHAT.value,
        "tencent_cloud_sid": "sid",
        "tencent_cloud_sk": "sk",
    }

    async def _tencent_request_json():
        return tencent_req

    monkeypatch.setattr(module, "get_request_json", _tencent_request_json)
    delegated = {}

    async def _fake_set_api_key():
        # Snapshot the api_key the route injected before delegating.
        delegated["api_key"] = tencent_req.get("api_key")
        return {"code": 0, "data": "delegated"}

    monkeypatch.setattr(module, "set_api_key", _fake_set_api_key)
    res = _run(module.add_llm())
    assert res["code"] == 0
    assert res["data"] == "delegated"
    assert json.loads(delegated["api_key"]) == {"tencent_cloud_sid": "sid", "tencent_cloud_sk": "sk"}
@pytest.mark.p2
def test_add_llm_model_type_probe_and_persistence_matrix_unit(monkeypatch):
    """Cover /add_llm probe failures for every model type (embedding, chat,
    rerank, image2text, tts, ocr, speech2text), the unknown-type error, and the
    save() fallback when filter_update reports no existing row."""
    module = _load_llm_app(monkeypatch)

    # Collapse async plumbing so probes run inline.
    async def _wait_for(coro, *_args, **_kwargs):
        return await coro

    async def _to_thread(fn, *args, **kwargs):
        return fn(*args, **kwargs)

    monkeypatch.setattr(module.asyncio, "wait_for", _wait_for)
    monkeypatch.setattr(module.asyncio, "to_thread", _to_thread)
    monkeypatch.setattr(
        module,
        "get_allowed_llm_factories",
        lambda: [
            SimpleNamespace(name=name)
            for name in [
                "FEmbFail",
                "FEmbPass",
                "FChatFail",
                "FChatPass",
                "FRKey",
                "FRFail",
                "FImgFail",
                "FTTSFail",
                "FOcrFail",
                "FSttFail",
                "FUnknown",
            ]
        ],
    )

    # One stub model per factory, each failing its probe in a type-specific way.
    class _EmbeddingFail:
        def __init__(self, *_args, **_kwargs):
            pass

        def encode(self, _texts):
            # Empty vector -> embedding probe failure.
            return [[]], 1

    class _EmbeddingPass:
        def __init__(self, *_args, **_kwargs):
            pass

        def encode(self, _texts):
            return [[0.5]], 1

    class _ChatFail:
        def __init__(self, *_args, **_kwargs):
            pass

        async def async_chat(self, *_args, **_kwargs):
            # "**ERROR**" prefix -> chat probe failure.
            return "**ERROR**: chat failed", 0

    class _ChatPass:
        def __init__(self, *_args, **_kwargs):
            pass

        async def async_chat(self, *_args, **_kwargs):
            return "ok", 1

    class _RerankFail:
        def __init__(self, *_args, **_kwargs):
            pass

        def similarity(self, *_args, **_kwargs):
            # Empty similarity list -> rerank probe failure.
            return [], 1

    class _CvFail:
        def __init__(self, *_args, **_kwargs):
            pass

        def describe(self, _image_data):
            return "**ERROR**: image failed", 0

    class _TTSFail:
        def __init__(self, *_args, **_kwargs):
            pass

        def tts(self, _text):
            raise RuntimeError("tts fail")
            # Unreachable yield keeps this a generator function, matching the
            # real TTS interface; the error is raised on first iteration.
            yield b"x"

    class _OcrFail:
        def __init__(self, *_args, **_kwargs):
            pass

        def check_available(self):
            return False, "ocr unavailable"

    class _SttFail:
        def __init__(self, *_args, **_kwargs):
            # Fails at construction time, before any probe call.
            raise RuntimeError("stt fail")

    class _RerankKeyMap(dict):
        # Simulates a registry that claims to contain "FRKey" but blows up on
        # lookup, exercising the "does not support this model" branch.
        def __contains__(self, key):
            if key == "FRKey":
                return True
            return super().__contains__(key)

        def __getitem__(self, key):
            if key == "FRKey":
                raise KeyError("rerank key fail")
            return super().__getitem__(key)

    monkeypatch.setattr(module, "EmbeddingModel", {"FEmbFail": _EmbeddingFail, "FEmbPass": _EmbeddingPass})
    monkeypatch.setattr(module, "ChatModel", {"FChatFail": _ChatFail, "FChatPass": _ChatPass})
    monkeypatch.setattr(module, "RerankModel", _RerankKeyMap({"FRFail": _RerankFail}))
    monkeypatch.setattr(module, "CvModel", {"FImgFail": _CvFail})
    monkeypatch.setattr(module, "TTSModel", {"FTTSFail": _TTSFail})
    monkeypatch.setattr(module, "OcrModel", {"FOcrFail": _OcrFail})
    monkeypatch.setattr(module, "Seq2txtModel", {"FSttFail": _SttFail})

    def _call(req):
        _set_request_json(monkeypatch, module, req)
        return _run(module.add_llm())

    # verify=True reports failures in-band (code 0, success=False)...
    res = _call({"llm_factory": "FEmbFail", "llm_name": "m", "model_type": module.LLMType.EMBEDDING.value, "verify": True})
    assert res["code"] == 0
    assert res["data"]["success"] is False
    assert "Fail to access embedding model(m)." in res["data"]["message"]
    # ...while the default (no verify) surfaces them as a 400 error.
    res = _call({"llm_factory": "FEmbFail", "llm_name": "m", "model_type": module.LLMType.EMBEDDING.value})
    assert res["code"] == 400
    assert "Fail to access embedding model(m)." in res["message"]
    res = _call({"llm_factory": "FChatFail", "llm_name": "m", "model_type": module.LLMType.CHAT.value, "verify": True})
    assert res["code"] == 0
    assert "Fail to access model(FChatFail/m)." in res["data"]["message"]
    res = _call({"llm_factory": "FRKey", "llm_name": "m", "model_type": module.LLMType.RERANK.value, "verify": True})
    assert res["code"] == 0
    # NOTE: "dose" is the app's actual (misspelled) message text.
    assert "dose not support this model(FRKey/m)" in res["data"]["message"]
    res = _call({"llm_factory": "FRFail", "llm_name": "m", "model_type": module.LLMType.RERANK.value, "verify": True})
    assert res["code"] == 0
    assert "Fail to access model(FRFail/m)." in res["data"]["message"]
    res = _call({"llm_factory": "FImgFail", "llm_name": "m", "model_type": module.LLMType.IMAGE2TEXT.value, "verify": True})
    assert res["code"] == 0
    assert "Fail to access model(FImgFail/m)." in res["data"]["message"]
    res = _call({"llm_factory": "FTTSFail", "llm_name": "m", "model_type": module.LLMType.TTS.value, "verify": True})
    assert res["code"] == 0
    assert "Fail to access model(FTTSFail/m)." in res["data"]["message"]
    res = _call({"llm_factory": "FOcrFail", "llm_name": "m", "model_type": module.LLMType.OCR.value, "verify": True})
    assert res["code"] == 0
    assert "Fail to access model(FOcrFail/m)." in res["data"]["message"]
    res = _call({"llm_factory": "FSttFail", "llm_name": "m", "model_type": module.LLMType.SPEECH2TEXT.value, "verify": True})
    assert res["code"] == 0
    assert "Fail to access model(FSttFail/m)." in res["data"]["message"]
    # An unrecognized model_type is a programming error and propagates.
    _set_request_json(monkeypatch, module, {"llm_factory": "FUnknown", "llm_name": "m", "model_type": "unknown"})
    with pytest.raises(RuntimeError, match="Unknown model type: unknown"):
        _run(module.add_llm())
    # Persistence fallback: filter_update misses -> save() is invoked.
    saved = []
    monkeypatch.setattr(module.TenantLLMService, "filter_update", lambda _filters, _payload: False)
    monkeypatch.setattr(module.TenantLLMService, "save", lambda **kwargs: saved.append(kwargs) or True)
    res = _call({"llm_factory": "FChatPass", "llm_name": "m", "model_type": module.LLMType.CHAT.value, "api_key": "k"})
    assert res["code"] == 0
    assert res["data"] is True
    assert saved
    assert saved[0]["llm_factory"] == "FChatPass"
@pytest.mark.p2
def test_llm_mutation_routes_unit(monkeypatch):
    """Exercise delete_llm, enable_llm, and delete_factory against a stubbed
    TenantLLMService, checking both the responses and the recorded calls."""
    module = _load_llm_app(monkeypatch)
    recorded = {"delete": [], "update": []}

    def _record_delete(filters):
        recorded["delete"].append(filters)
        return True

    def _record_update(filters, payload):
        recorded["update"].append((filters, payload))
        return True

    monkeypatch.setattr(module.TenantLLMService, "filter_delete", _record_delete)
    monkeypatch.setattr(module.TenantLLMService, "filter_update", _record_update)

    # delete_llm removes a single model entry via filter_delete.
    _set_request_json(monkeypatch, module, {"llm_factory": "OpenAI", "llm_name": "gpt"})
    res = _run(module.delete_llm())
    assert res["code"] == 0
    assert res["data"] is True

    # enable_llm persists the status flag as a string.
    _set_request_json(monkeypatch, module, {"llm_factory": "OpenAI", "llm_name": "gpt", "status": 0})
    res = _run(module.enable_llm())
    assert res["code"] == 0
    assert res["data"] is True
    assert recorded["update"][0][1]["status"] == "0"

    # delete_factory issues a second filter_delete covering the whole factory.
    _set_request_json(monkeypatch, module, {"llm_factory": "OpenAI"})
    res = _run(module.delete_factory())
    assert res["code"] == 0
    assert res["data"] is True
    assert len(recorded["delete"]) == 2
@pytest.mark.p2
def test_my_llms_include_details_and_exception_unit(monkeypatch):
    """Cover /my_llms with include_details=true: the per-factory grouping with
    tags and token details, plus the 500 path when the MinerU env sync throws."""
    module = _load_llm_app(monkeypatch)
    monkeypatch.setattr(module, "request", SimpleNamespace(args={"include_details": "true"}))
    ensure_calls = []
    # Record which tenant the route syncs MinerU config for.
    monkeypatch.setattr(module.TenantLLMService, "ensure_mineru_from_env", lambda tenant_id: ensure_calls.append(tenant_id))
    monkeypatch.setattr(
        module.TenantLLMService,
        "query",
        lambda **_kwargs: [
            _TenantLLMRow(
                llm_name="chat-model",
                llm_factory="FactoryX",
                model_type="chat",
                used_tokens=42,
                api_base="",
                max_tokens=4096,
                status="1",
            )
        ],
    )
    monkeypatch.setattr(module.LLMFactoriesService, "query", lambda **_kwargs: [SimpleNamespace(name="FactoryX", tags=["tag-a"])])
    res = module.my_llms()
    assert res["code"] == 0
    assert ensure_calls == ["tenant-1"]
    # Rows are grouped by factory name, with factory tags and per-model details.
    assert "FactoryX" in res["data"]
    assert res["data"]["FactoryX"]["tags"] == ["tag-a"]
    assert res["data"]["FactoryX"]["llm"][0]["used_token"] == 42
    assert res["data"]["FactoryX"]["llm"][0]["max_tokens"] == 4096
    # Exception path: a throwing env sync maps to a 500 response.
    monkeypatch.setattr(module.TenantLLMService, "ensure_mineru_from_env", lambda _tenant_id: (_ for _ in ()).throw(RuntimeError("my llms boom")))
    res = module.my_llms()
    assert res["code"] == 500
    assert "my llms boom" in res["message"]

View File

@ -16,6 +16,7 @@
import asyncio
import importlib.util
import inspect
import json
import sys
from functools import wraps
from pathlib import Path
@ -131,6 +132,16 @@ def _set_request_json(monkeypatch, module, payload):
monkeypatch.setattr(module, "get_request_json", _request_json)
@pytest.fixture(scope="session")
def auth():
return "unit-auth"
@pytest.fixture(scope="session", autouse=True)
def set_tenant_info():
return None
def _load_mcp_server_app(monkeypatch):
repo_root = Path(__file__).resolve().parents[4]
@ -197,6 +208,28 @@ def _load_mcp_server_app(monkeypatch):
api_utils_mod.get_mcp_tools = _get_mcp_tools
monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod)
web_utils_mod = ModuleType("api.utils.web_utils")
def _get_float(data, key, default):
try:
return float(data.get(key, default))
except (TypeError, ValueError):
return default
def _safe_json_parse(value):
if isinstance(value, (dict, list)):
return value
if value in (None, ""):
return {}
try:
return json.loads(value)
except (TypeError, ValueError):
return {}
web_utils_mod.get_float = _get_float
web_utils_mod.safe_json_parse = _safe_json_parse
monkeypatch.setitem(sys.modules, "api.utils.web_utils", web_utils_mod)
module_name = "test_mcp_server_app_unit_module"
module_path = repo_root / "api" / "apps" / "mcp_server_app.py"
spec = importlib.util.spec_from_file_location(module_name, module_path)
@ -706,3 +739,159 @@ def test_test_tool_missing_mcp_id(monkeypatch):
_set_request_json(monkeypatch, module, {"mcp_id": "", "tool_name": "tool_a", "arguments": {"x": 1}})
res = _run(module.test_tool.__wrapped__())
assert "No MCP server ID provided" in res["message"]
@pytest.mark.p2
def test_test_tool_route_matrix_unit(monkeypatch):
    """Cover /test_tool: argument validation, missing/foreign-tenant server
    lookups, a successful tool call with session cleanup, and the error path."""
    module = _load_mcp_server_app(monkeypatch)
    # Validation: empty mcp_id, empty tool name, empty arguments.
    _set_request_json(monkeypatch, module, {"mcp_id": "", "tool_name": "tool_a", "arguments": {"x": 1}})
    res = _run(module.test_tool.__wrapped__())
    assert "No MCP server ID provided" in res["message"]
    _set_request_json(monkeypatch, module, {"mcp_id": "id1", "tool_name": "", "arguments": {"x": 1}})
    res = _run(module.test_tool.__wrapped__())
    assert "Require provide tool name and arguments" in res["message"]
    _set_request_json(monkeypatch, module, {"mcp_id": "id1", "tool_name": "tool_a", "arguments": {}})
    res = _run(module.test_tool.__wrapped__())
    assert "Require provide tool name and arguments" in res["message"]
    # Lookup failures: not found, and found but owned by another tenant.
    _set_request_json(monkeypatch, module, {"mcp_id": "id1", "tool_name": "tool_a", "arguments": {"x": 1}})
    monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (False, None))
    res = _run(module.test_tool.__wrapped__())
    assert "Cannot find MCP server id1 for user tenant_1" in res["message"]
    server_other = _DummyMCPServer(id="id1", name="srv", url="http://a", server_type="sse", tenant_id="other", variables={})
    monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (True, server_other))
    res = _run(module.test_tool.__wrapped__())
    assert "Cannot find MCP server id1 for user tenant_1" in res["message"]
    # Success path: tool call returns "ok" and exactly one session is closed.
    server_ok = _DummyMCPServer(id="id1", name="srv", url="http://a", server_type="sse", tenant_id="tenant_1", variables={})
    monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (True, server_ok))
    close_calls = []

    async def _thread_pool_exec_success(func, *args):
        # Intercept the cleanup call to record which sessions get closed.
        if func is module.close_multiple_mcp_toolcall_sessions:
            close_calls.append(args[0])
            return None
        return func(*args)

    monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_exec_success)
    res = _run(module.test_tool.__wrapped__())
    assert res["code"] == 0
    assert res["data"] == "ok"
    assert close_calls and len(close_calls[-1]) == 1

    async def _thread_pool_exec_raise(func, *args):
        # Let cleanup succeed but make the actual tool call explode.
        if func is module.close_multiple_mcp_toolcall_sessions:
            return None
        raise RuntimeError("tool call explode")

    monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_exec_raise)
    res = _run(module.test_tool.__wrapped__())
    assert res["code"] == 100
    assert "tool call explode" in res["message"]
@pytest.mark.p2
def test_cache_tool_route_matrix_unit(monkeypatch):
    """Cover /cache_tool: validation, lookup failures, a failed persistence
    update, and the success path that replaces the cached tool map (skipping
    malformed entries)."""
    module = _load_mcp_server_app(monkeypatch)
    _set_request_json(monkeypatch, module, {"mcp_id": "", "tools": [{"name": "tool_a"}]})
    res = _run(module.cache_tool.__wrapped__())
    assert "No MCP server ID provided" in res["message"]
    # Lookup failures: not found, and found but owned by another tenant.
    _set_request_json(monkeypatch, module, {"mcp_id": "id1", "tools": [{"name": "tool_a"}]})
    monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (False, None))
    res = _run(module.cache_tool.__wrapped__())
    assert "Cannot find MCP server id1 for user tenant_1" in res["message"]
    server_other = _DummyMCPServer(id="id1", name="srv", url="http://a", server_type="sse", tenant_id="other", variables={})
    monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (True, server_other))
    res = _run(module.cache_tool.__wrapped__())
    assert "Cannot find MCP server id1 for user tenant_1" in res["message"]
    # Persistence failure: filter_update returning False is reported as an error.
    server_fail = _DummyMCPServer(id="id1", name="srv", url="http://a", server_type="sse", tenant_id="tenant_1", variables={})
    monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (True, server_fail))
    monkeypatch.setattr(module.MCPServerService, "filter_update", lambda *_args, **_kwargs: False)
    res = _run(module.cache_tool.__wrapped__())
    assert "Failed to updated MCP server" in res["message"]
    # Success path: previously-cached "old_tool" is replaced wholesale.
    server_ok = _DummyMCPServer(
        id="id1",
        name="srv",
        url="http://a",
        server_type="sse",
        tenant_id="tenant_1",
        variables={"tools": {"old_tool": {"name": "old_tool"}}},
    )
    monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (True, server_ok))
    monkeypatch.setattr(module.MCPServerService, "filter_update", lambda *_args, **_kwargs: True)
    # The tools list deliberately mixes valid dicts with a nameless dict and a
    # bare string; only well-formed entries should survive.
    _set_request_json(
        monkeypatch,
        module,
        {
            "mcp_id": "id1",
            "tools": [{"name": "tool_a", "enabled": True}, {"bad": 1}, "x", {"name": "tool_b", "enabled": False}],
        },
    )
    res = _run(module.cache_tool.__wrapped__())
    assert res["code"] == 0
    assert sorted(res["data"].keys()) == ["tool_a", "tool_b"]
    assert server_ok.variables["tools"]["tool_b"]["enabled"] is False
@pytest.mark.p2
def test_test_mcp_route_matrix_unit(monkeypatch):
    """Cover /test_mcp: url/server_type validation, a get_tools failure with
    session cleanup, the success path, and a session-construction failure."""
    module = _load_mcp_server_app(monkeypatch)
    _set_request_json(monkeypatch, module, {"url": "", "server_type": "sse"})
    res = _run(module.test_mcp.__wrapped__())
    assert "Invalid MCP url" in res["message"]
    _set_request_json(monkeypatch, module, {"url": "http://a", "server_type": "invalid"})
    res = _run(module.test_mcp.__wrapped__())
    assert "Unsupported MCP server type" in res["message"]
    close_calls = []

    async def _thread_pool_exec_inner_error(func, *args):
        # Record session cleanup, but make the get_tools probe itself fail.
        if func is module.close_multiple_mcp_toolcall_sessions:
            close_calls.append(args[0])
            return None
        if getattr(func, "__name__", "") == "get_tools":
            raise RuntimeError("get tools explode")
        return func(*args)

    monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_exec_inner_error)
    _set_request_json(monkeypatch, module, {"url": "http://a", "server_type": "sse"})
    res = _run(module.test_mcp.__wrapped__())
    assert res["code"] == 102
    assert "Test MCP error: get tools explode" in res["message"]
    # The session is still closed even when the probe fails.
    assert close_calls and len(close_calls[-1]) == 1
    close_calls_success = []

    async def _thread_pool_exec_success(func, *args):
        if func is module.close_multiple_mcp_toolcall_sessions:
            close_calls_success.append(args[0])
            return None
        return func(*args)

    monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_exec_success)
    _set_request_json(monkeypatch, module, {"url": "http://a", "server_type": "sse"})
    res = _run(module.test_mcp.__wrapped__())
    assert res["code"] == 0
    assert res["data"][0]["name"] == "tool_a"
    # Every discovered tool is reported as enabled by default.
    assert all(tool["enabled"] is True for tool in res["data"])
    assert close_calls_success and len(close_calls_success[-1]) == 1

    def _raise_session(*_args, **_kwargs):
        raise RuntimeError("session explode")

    # Failure to even construct the tool-call session maps to code 100.
    monkeypatch.setattr(module, "MCPToolCallSession", _raise_session)
    _set_request_json(monkeypatch, module, {"url": "http://a", "server_type": "sse"})
    res = _run(module.test_mcp.__wrapped__())
    assert res["code"] == 100
    assert "session explode" in res["message"]

View File

@ -0,0 +1,509 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
from copy import deepcopy
import importlib.util
import sys
from pathlib import Path
from types import ModuleType, SimpleNamespace
import pytest
class _DummyManager:
def route(self, *_args, **_kwargs):
def decorator(func):
return func
return decorator
class _DummyAtomic:
def __enter__(self):
return self
def __exit__(self, _exc_type, _exc, _tb):
return False
class _Args(dict):
def get(self, key, default=None):
return super().get(key, default)
class _EnumValue:
    # Minimal enum-member stand-in: exposes only the ``.value`` attribute that
    # the app code reads (e.g. ``StatusEnum.VALID.value``).
    def __init__(self, value):
        self.value = value
class _DummyStatusEnum:
    # Mirrors common.constants.StatusEnum just enough for search_app:
    # a single VALID member whose ``.value`` is the string "1".
    VALID = _EnumValue("1")
class _DummyRetCode:
    # Return-code constants mirroring common.constants.RetCode; the numeric
    # values must stay in sync with the codes the route assertions expect.
    SUCCESS = 0
    EXCEPTION_ERROR = 100
    ARGUMENT_ERROR = 101
    DATA_ERROR = 102
    OPERATING_ERROR = 103
    AUTHENTICATION_ERROR = 109
class _SearchRecord:
def __init__(self, search_id="search-1", name="search", search_config=None):
self.id = search_id
self.name = name
self.search_config = {} if search_config is None else dict(search_config)
def to_dict(self):
return {"id": self.id, "name": self.name, "search_config": dict(self.search_config)}
def _run(coro):
return asyncio.run(coro)
def _set_request_json(monkeypatch, module, payload):
async def _request_json():
return deepcopy(payload)
monkeypatch.setattr(module, "get_request_json", _request_json)
def _set_request_args(monkeypatch, module, args=None):
    """Install a fake ``request`` on *module* whose ``args`` wraps the given mapping."""
    query_args = _Args(args or {})
    monkeypatch.setattr(module, "request", SimpleNamespace(args=query_args))
@pytest.fixture(scope="session")
def auth():
return "unit-auth"
@pytest.fixture(scope="session", autouse=True)
def set_tenant_info():
return None
def _load_search_app(monkeypatch):
    """Import ``api/apps/search_app.py`` in isolation.

    Every module the app imports (quart, common.*, api.*) is replaced in
    sys.modules with a lightweight stub BEFORE the file is executed, so the
    routes can be exercised without a running server or database. Returns
    the freshly-loaded module object.
    """
    repo_root = Path(__file__).resolve().parents[4]
    # --- quart stub: only `request` with empty query args is needed ---
    quart_mod = ModuleType("quart")
    quart_mod.request = SimpleNamespace(args=_Args())
    monkeypatch.setitem(sys.modules, "quart", quart_mod)
    # --- common package stubs ---
    common_pkg = ModuleType("common")
    common_pkg.__path__ = [str(repo_root / "common")]
    monkeypatch.setitem(sys.modules, "common", common_pkg)
    misc_utils_mod = ModuleType("common.misc_utils")
    # Deterministic uuid so tests can assert the created search id.
    misc_utils_mod.get_uuid = lambda: "search-uuid-1"
    monkeypatch.setitem(sys.modules, "common.misc_utils", misc_utils_mod)
    common_pkg.misc_utils = misc_utils_mod
    constants_mod = ModuleType("common.constants")
    constants_mod.RetCode = _DummyRetCode
    constants_mod.StatusEnum = _DummyStatusEnum
    monkeypatch.setitem(sys.modules, "common.constants", constants_mod)
    common_pkg.constants = constants_mod
    # --- api package and auth stubs ---
    api_pkg = ModuleType("api")
    api_pkg.__path__ = [str(repo_root / "api")]
    monkeypatch.setitem(sys.modules, "api", api_pkg)
    apps_mod = ModuleType("api.apps")
    apps_mod.__path__ = [str(repo_root / "api" / "apps")]
    # Fixed current user; login_required becomes a pass-through decorator.
    apps_mod.current_user = SimpleNamespace(id="tenant-1")
    apps_mod.login_required = lambda func: func
    monkeypatch.setitem(sys.modules, "api.apps", apps_mod)
    api_pkg.apps = apps_mod
    constants_api_mod = ModuleType("api.constants")
    constants_api_mod.DATASET_NAME_LIMIT = 255
    monkeypatch.setitem(sys.modules, "api.constants", constants_api_mod)
    # --- database stubs ---
    db_pkg = ModuleType("api.db")
    db_pkg.__path__ = []
    monkeypatch.setitem(sys.modules, "api.db", db_pkg)
    api_pkg.db = db_pkg
    db_models_mod = ModuleType("api.db.db_models")

    class _DummyDB:
        @staticmethod
        def atomic():
            # No-op transaction context.
            return _DummyAtomic()

    db_models_mod.DB = _DummyDB
    monkeypatch.setitem(sys.modules, "api.db.db_models", db_models_mod)
    services_pkg = ModuleType("api.db.services")
    services_pkg.__path__ = []
    # Default duplicate_name: pass the requested name through unchanged.
    services_pkg.duplicate_name = lambda _checker, **kwargs: kwargs.get("name", "")
    monkeypatch.setitem(sys.modules, "api.db.services", services_pkg)
    search_service_mod = ModuleType("api.db.services.search_service")

    class _SearchService:
        # Happy-path defaults for every service call the routes make; tests
        # override individual methods per scenario via monkeypatch.
        @staticmethod
        def query(**_kwargs):
            return []

        @staticmethod
        def save(**_kwargs):
            return True

        @staticmethod
        def accessible4deletion(_search_id, _user_id):
            return True

        @staticmethod
        def update_by_id(_search_id, _req):
            return True

        @staticmethod
        def get_by_id(_search_id):
            return True, _SearchRecord(search_id=_search_id, name="updated")

        @staticmethod
        def get_detail(_search_id):
            return {"id": _search_id}

        @staticmethod
        def get_by_tenant_ids(_tenants, _user_id, _page_number, _items_per_page, _orderby, _desc, _keywords):
            return [], 0

        @staticmethod
        def delete_by_id(_search_id):
            return True

    search_service_mod.SearchService = _SearchService
    monkeypatch.setitem(sys.modules, "api.db.services.search_service", search_service_mod)
    user_service_mod = ModuleType("api.db.services.user_service")

    class _TenantService:
        @staticmethod
        def get_by_id(_tenant_id):
            return True, SimpleNamespace(id=_tenant_id)

    class _UserTenantService:
        @staticmethod
        def query(**_kwargs):
            return [SimpleNamespace(tenant_id="tenant-1")]

    user_service_mod.TenantService = _TenantService
    user_service_mod.UserTenantService = _UserTenantService
    monkeypatch.setitem(sys.modules, "api.db.services.user_service", user_service_mod)
    # --- api.utils.api_utils stub: response helpers and no-op decorators ---
    utils_pkg = ModuleType("api.utils")
    utils_pkg.__path__ = []
    monkeypatch.setitem(sys.modules, "api.utils", utils_pkg)
    api_utils_mod = ModuleType("api.utils.api_utils")

    async def _default_request_json():
        return {}

    def _get_data_error_result(code=_DummyRetCode.DATA_ERROR, message="Sorry! Data missing!"):
        return {"code": code, "message": message}

    def _get_json_result(code=_DummyRetCode.SUCCESS, message="success", data=None):
        return {"code": code, "message": message, "data": data}

    def _server_error_response(error):
        return {"code": _DummyRetCode.EXCEPTION_ERROR, "message": repr(error)}

    def _validate_request(*_args, **_kwargs):
        def _decorator(func):
            return func

        return _decorator

    def _not_allowed_parameters(*_params):
        def _decorator(func):
            return func

        return _decorator

    api_utils_mod.get_request_json = _default_request_json
    api_utils_mod.get_data_error_result = _get_data_error_result
    api_utils_mod.get_json_result = _get_json_result
    api_utils_mod.server_error_response = _server_error_response
    api_utils_mod.validate_request = _validate_request
    api_utils_mod.not_allowed_parameters = _not_allowed_parameters
    monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod)
    utils_pkg.api_utils = api_utils_mod
    # --- execute search_app.py under a private module name ---
    module_name = "test_search_routes_unit_module"
    module_path = repo_root / "api" / "apps" / "search_app.py"
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    module = importlib.util.module_from_spec(spec)
    # The app file expects a blueprint `manager` to exist before exec.
    module.manager = _DummyManager()
    monkeypatch.setitem(sys.modules, module_name, module)
    spec.loader.exec_module(module)
    return module
@pytest.mark.p2
def test_create_route_matrix_unit(monkeypatch):
    """Cover /create: name validation (type, blank, length), tenant-auth
    failure, save failure, successful creation, and the exception path."""
    module = _load_search_app(monkeypatch)
    # Name must be a string.
    _set_request_json(monkeypatch, module, {"name": 1})
    res = _run(module.create())
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "must be string" in res["message"]
    # Name must not be blank.
    _set_request_json(monkeypatch, module, {"name": " "})
    res = _run(module.create())
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "empty" in res["message"].lower()
    # Name must respect the 255-character limit.
    _set_request_json(monkeypatch, module, {"name": "a" * 256})
    res = _run(module.create())
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "255" in res["message"]
    # Tenant lookup failure is reported as a data error.
    _set_request_json(monkeypatch, module, {"name": "create-auth-fail"})
    monkeypatch.setattr(module.TenantService, "get_by_id", lambda _tenant_id: (False, None))
    res = _run(module.create())
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "authorized identity" in res["message"].lower()
    monkeypatch.setattr(module.TenantService, "get_by_id", lambda _tenant_id: (True, SimpleNamespace(id=_tenant_id)))
    monkeypatch.setattr(module, "duplicate_name", lambda _checker, **kwargs: kwargs["name"] + "_dedup")
    # save() returning False maps to a data error.
    _set_request_json(monkeypatch, module, {"name": "create-fail", "description": "d"})
    monkeypatch.setattr(module.SearchService, "save", lambda **_kwargs: False)
    res = _run(module.create())
    assert res["code"] == module.RetCode.DATA_ERROR
    # Success path: the stubbed get_uuid fixes the returned search_id.
    _set_request_json(monkeypatch, module, {"name": "create-ok", "description": "d"})
    monkeypatch.setattr(module.SearchService, "save", lambda **_kwargs: True)
    res = _run(module.create())
    assert res["code"] == 0
    assert res["data"]["search_id"] == "search-uuid-1"

    def _raise_save(**_kwargs):
        raise RuntimeError("save boom")

    # A throwing save() is wrapped into the exception-error response.
    monkeypatch.setattr(module.SearchService, "save", _raise_save)
    _set_request_json(monkeypatch, module, {"name": "create-exception", "description": "d"})
    res = _run(module.create())
    assert res["code"] == module.RetCode.EXCEPTION_ERROR
    assert "save boom" in res["message"]
@pytest.mark.p2
def test_update_and_detail_route_matrix_unit(monkeypatch):
    """Walk the search-app ``update`` and ``detail`` routes through their
    error and success branches.

    NOTE: the monkeypatched stubs are installed sequentially and deliberately
    persist into later calls, so the scenario order below matters.
    """
    module = _load_search_app(monkeypatch)
    # update(): name must be a string.
    _set_request_json(monkeypatch, module, {"search_id": "s1", "name": 1, "search_config": {}, "tenant_id": "tenant-1"})
    res = _run(module.update())
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "must be string" in res["message"]
    # update(): name must not be blank.
    _set_request_json(monkeypatch, module, {"search_id": "s1", "name": " ", "search_config": {}, "tenant_id": "tenant-1"})
    res = _run(module.update())
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "empty" in res["message"].lower()
    # update(): over-long name; "large than" matches the route's actual
    # (ungrammatical) message text — do not "fix" this expectation.
    _set_request_json(monkeypatch, module, {"search_id": "s1", "name": "a" * 256, "search_config": {}, "tenant_id": "tenant-1"})
    res = _run(module.update())
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "large than" in res["message"]
    # update(): tenant lookup failure -> unauthorized identity.
    _set_request_json(monkeypatch, module, {"search_id": "s1", "name": "ok", "search_config": {}, "tenant_id": "tenant-1"})
    monkeypatch.setattr(module.TenantService, "get_by_id", lambda _tenant_id: (False, None))
    res = _run(module.update())
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "authorized identity" in res["message"].lower()
    monkeypatch.setattr(module.TenantService, "get_by_id", lambda _tenant_id: (True, SimpleNamespace(id=_tenant_id)))
    # update(): accessibility check rejection.
    monkeypatch.setattr(module.SearchService, "accessible4deletion", lambda _search_id, _user_id: False)
    _set_request_json(monkeypatch, module, {"search_id": "s1", "name": "ok", "search_config": {}, "tenant_id": "tenant-1"})
    res = _run(module.update())
    assert res["code"] == module.RetCode.AUTHENTICATION_ERROR
    assert "authorization" in res["message"].lower()
    monkeypatch.setattr(module.SearchService, "accessible4deletion", lambda _search_id, _user_id: True)
    # update(): the target search record cannot be found.
    monkeypatch.setattr(module.SearchService, "query", lambda **_kwargs: [None])
    _set_request_json(monkeypatch, module, {"search_id": "s1", "name": "ok", "search_config": {}, "tenant_id": "tenant-1"})
    res = _run(module.update())
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "cannot find search" in res["message"].lower()
    # update(): renaming onto a name already used by another record is rejected.
    existing = _SearchRecord(search_id="s1", name="old-name", search_config={"existing": 1})

    def _query_duplicate(**kwargs):
        # Id lookups return the record itself; name lookups return a duplicate.
        if "id" in kwargs:
            return [existing]
        if "name" in kwargs:
            return [SimpleNamespace(id="dup")]
        return []

    monkeypatch.setattr(module.SearchService, "query", _query_duplicate)
    _set_request_json(monkeypatch, module, {"search_id": "s1", "name": "new-name", "search_config": {}, "tenant_id": "tenant-1"})
    res = _run(module.update())
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "duplicated" in res["message"].lower()
    # update(): search_config must be a JSON object, not a list.
    monkeypatch.setattr(module.SearchService, "query", lambda **_kwargs: [existing])
    _set_request_json(monkeypatch, module, {"search_id": "s1", "name": "old-name", "search_config": [], "tenant_id": "tenant-1"})
    res = _run(module.update())
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "json object" in res["message"].lower()
    # update(): persistence failure; also verify the payload handed to
    # update_by_id has the ids stripped and search_config merged with
    # the stored configuration.
    captured = {}

    def _update_fail(search_id, req):
        captured["search_id"] = search_id
        captured["req"] = dict(req)
        return False

    monkeypatch.setattr(module.SearchService, "update_by_id", _update_fail)
    _set_request_json(monkeypatch, module, {"search_id": "s1", "name": "old-name", "search_config": {"top_k": 3}, "tenant_id": "tenant-1"})
    res = _run(module.update())
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "failed to update" in res["message"].lower()
    assert captured["search_id"] == "s1"
    assert "search_id" not in captured["req"]
    assert "tenant_id" not in captured["req"]
    assert captured["req"]["search_config"] == {"existing": 1, "top_k": 3}
    # update(): the write succeeds but the follow-up re-fetch fails.
    monkeypatch.setattr(module.SearchService, "update_by_id", lambda _search_id, _req: True)
    monkeypatch.setattr(module.SearchService, "get_by_id", lambda _search_id: (False, None))
    res = _run(module.update())
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "failed to fetch" in res["message"].lower()
    # update(): full success path.
    monkeypatch.setattr(
        module.SearchService,
        "get_by_id",
        lambda _search_id: (True, _SearchRecord(search_id=_search_id, name="old-name", search_config={"existing": 1, "top_k": 3})),
    )
    res = _run(module.update())
    assert res["code"] == 0
    assert res["data"]["id"] == "s1"

    # update(): an unexpected exception is surfaced as EXCEPTION_ERROR.
    def _raise_query(**_kwargs):
        raise RuntimeError("update boom")

    monkeypatch.setattr(module.SearchService, "query", _raise_query)
    _set_request_json(monkeypatch, module, {"search_id": "s1", "name": "old-name", "search_config": {"top_k": 3}, "tenant_id": "tenant-1"})
    res = _run(module.update())
    assert res["code"] == module.RetCode.EXCEPTION_ERROR
    assert "update boom" in res["message"]
    # detail(): permission denied when the search is not visible to any of
    # the user's tenants.
    _set_request_args(monkeypatch, module, {"search_id": "s1"})
    monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [SimpleNamespace(tenant_id="tenant-a")])
    monkeypatch.setattr(module.SearchService, "query", lambda **_kwargs: [])
    res = module.detail()
    assert res["code"] == module.RetCode.OPERATING_ERROR
    assert "permission" in res["message"].lower()
    # detail(): record visible but the detail payload is missing.
    monkeypatch.setattr(module.SearchService, "query", lambda **_kwargs: [SimpleNamespace(id="s1")])
    monkeypatch.setattr(module.SearchService, "get_detail", lambda _search_id: None)
    res = module.detail()
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "can't find" in res["message"].lower()
    # detail(): success path.
    monkeypatch.setattr(module.SearchService, "get_detail", lambda _search_id: {"id": _search_id, "name": "detail-name"})
    res = module.detail()
    assert res["code"] == 0
    assert res["data"]["id"] == "s1"

    # detail(): an unexpected exception is surfaced as EXCEPTION_ERROR.
    def _raise_detail(_search_id):
        raise RuntimeError("detail boom")

    monkeypatch.setattr(module.SearchService, "get_detail", _raise_detail)
    res = module.detail()
    assert res["code"] == module.RetCode.EXCEPTION_ERROR
    assert "detail boom" in res["message"]
@pytest.mark.p2
def test_list_and_rm_route_matrix_unit(monkeypatch):
    """Cover the search-app ``list`` route (pagination, owner filtering,
    exception path) and the ``rm`` route (auth, failure, success, exception).

    NOTE: stubs are installed sequentially and persist across calls, so the
    scenario order matters.
    """
    module = _load_search_app(monkeypatch)
    # list: single page containing one record.
    _set_request_args(
        monkeypatch,
        module,
        {"keywords": "k", "page": "1", "page_size": "2", "orderby": "create_time", "desc": "false"},
    )
    _set_request_json(monkeypatch, module, {"owner_ids": []})
    monkeypatch.setattr(
        module.SearchService,
        "get_by_tenant_ids",
        lambda _tenants, _uid, _page, _size, _orderby, _desc, _keywords: ([{"id": "a", "tenant_id": "tenant-1"}], 1),
    )
    res = _run(module.list_search_app())
    assert res["code"] == 0
    assert res["data"]["total"] == 1
    assert res["data"]["search_apps"][0]["id"] == "a"
    # list: owner_ids filtering keeps only apps from the requested tenants,
    # with the total recomputed after filtering and page_size applied.
    _set_request_args(
        monkeypatch,
        module,
        {"keywords": "k", "page": "1", "page_size": "1", "orderby": "create_time", "desc": "true"},
    )
    _set_request_json(monkeypatch, module, {"owner_ids": ["tenant-1"]})
    monkeypatch.setattr(
        module.SearchService,
        "get_by_tenant_ids",
        lambda _tenants, _uid, _page, _size, _orderby, _desc, _keywords: (
            [{"id": "x", "tenant_id": "tenant-1"}, {"id": "y", "tenant_id": "tenant-2"}],
            2,
        ),
    )
    res = _run(module.list_search_app())
    assert res["code"] == 0
    assert res["data"]["total"] == 1
    assert len(res["data"]["search_apps"]) == 1
    assert res["data"]["search_apps"][0]["tenant_id"] == "tenant-1"

    # list: an unexpected exception is surfaced as EXCEPTION_ERROR.
    def _raise_list(*_args, **_kwargs):
        raise RuntimeError("list boom")

    monkeypatch.setattr(module.SearchService, "get_by_tenant_ids", _raise_list)
    _set_request_json(monkeypatch, module, {"owner_ids": []})
    res = _run(module.list_search_app())
    assert res["code"] == module.RetCode.EXCEPTION_ERROR
    assert "list boom" in res["message"]
    # rm: accessibility check rejection.
    _set_request_json(monkeypatch, module, {"search_id": "search-1"})
    monkeypatch.setattr(module.SearchService, "accessible4deletion", lambda _search_id, _user_id: False)
    res = _run(module.rm())
    assert res["code"] == module.RetCode.AUTHENTICATION_ERROR
    assert "authorization" in res["message"].lower()
    # rm: deletion failure reported as a data error.
    monkeypatch.setattr(module.SearchService, "accessible4deletion", lambda _search_id, _user_id: True)
    monkeypatch.setattr(module.SearchService, "delete_by_id", lambda _search_id: False)
    res = _run(module.rm())
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "failed to delete" in res["message"].lower()
    # rm: success path.
    monkeypatch.setattr(module.SearchService, "delete_by_id", lambda _search_id: True)
    res = _run(module.rm())
    assert res["code"] == 0
    assert res["data"] is True

    # rm: an unexpected exception is surfaced as EXCEPTION_ERROR.
    def _raise_delete(_search_id):
        raise RuntimeError("rm boom")

    monkeypatch.setattr(module.SearchService, "delete_by_id", _raise_delete)
    res = _run(module.rm())
    assert res["code"] == module.RetCode.EXCEPTION_ERROR
    assert "rm boom" in res["message"]

View File

@ -0,0 +1,322 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import importlib.util
import sys
from pathlib import Path
from types import ModuleType, SimpleNamespace
import pytest
class _DummyManager:
def route(self, *_args, **_kwargs):
def decorator(func):
return func
return decorator
class _ExprField:
def __init__(self, name):
self.name = name
def __eq__(self, other):
return (self.name, other)
class _DummyAPITokenModel:
    """Peewee-model stand-in for ``APIToken``, exposing only the two columns
    the system-app routes build filter expressions on."""

    # Column stubs: comparing them yields (name, value) tuples via _ExprField.
    tenant_id = _ExprField("tenant_id")
    token = _ExprField("token")
@pytest.fixture(scope="session")
def auth():
    """Static stand-in auth token; presumably shadows a conftest fixture
    that would otherwise perform a real login — TODO confirm."""
    return "unit-auth"
@pytest.fixture(scope="session", autouse=True)
def set_tenant_info():
    """Autouse no-op; presumably overrides a conftest fixture with external
    side effects so these unit tests stay offline — TODO confirm."""
    return None
def _load_system_module(monkeypatch):
    """Load ``api/apps/system_app.py`` in isolation with every import stubbed.

    All dependencies are replaced with in-memory stand-ins via ``sys.modules``
    *before* the module file is executed, so the route functions can be called
    directly without a running app, database, storage backend, or Redis.
    Installation order matters: each stub must be registered before
    ``exec_module`` triggers the corresponding import.
    """
    repo_root = Path(__file__).resolve().parents[4]
    # Stub the "api" package hierarchy (real __path__ kept so submodules resolve).
    api_pkg = ModuleType("api")
    api_pkg.__path__ = [str(repo_root / "api")]
    monkeypatch.setitem(sys.modules, "api", api_pkg)
    apps_mod = ModuleType("api.apps")
    apps_mod.__path__ = [str(repo_root / "api" / "apps")]
    # login_required becomes a pass-through so routes are directly callable.
    apps_mod.login_required = lambda fn: fn
    apps_mod.current_user = SimpleNamespace(id="user-1")
    monkeypatch.setitem(sys.modules, "api.apps", apps_mod)
    common_pkg = ModuleType("common")
    common_pkg.__path__ = [str(repo_root / "common")]
    monkeypatch.setitem(sys.modules, "common", common_pkg)
    # Settings stub: healthy doc store / storage by default; tests re-patch these.
    settings_mod = ModuleType("common.settings")
    settings_mod.docStoreConn = SimpleNamespace(health=lambda: {"type": "doc", "status": "green"})
    settings_mod.STORAGE_IMPL = SimpleNamespace(health=lambda: True)
    settings_mod.STORAGE_IMPL_TYPE = "MINIO"
    settings_mod.DATABASE_TYPE = "MYSQL"
    settings_mod.REGISTER_ENABLED = True
    common_pkg.settings = settings_mod
    monkeypatch.setitem(sys.modules, "common.settings", settings_mod)
    versions_mod = ModuleType("common.versions")
    versions_mod.get_ragflow_version = lambda: "0.0.0-unit"
    monkeypatch.setitem(sys.modules, "common.versions", versions_mod)
    # Deterministic clock helpers so timing-derived values are stable.
    time_utils_mod = ModuleType("common.time_utils")
    time_utils_mod.current_timestamp = lambda: 111
    time_utils_mod.datetime_format = lambda _dt: "2026-01-01 00:00:00"
    monkeypatch.setitem(sys.modules, "common.time_utils", time_utils_mod)
    # Response helpers mirror the real api_utils envelope shape.
    api_utils_mod = ModuleType("api.utils.api_utils")
    api_utils_mod.get_json_result = lambda data=None, message="success", code=0: {
        "code": code,
        "message": message,
        "data": data,
    }
    api_utils_mod.get_data_error_result = lambda message="", code=102, data=None: {
        "code": code,
        "message": message,
        "data": data,
    }
    api_utils_mod.server_error_response = lambda exc: {
        "code": 100,
        "message": repr(exc),
        "data": None,
    }
    api_utils_mod.generate_confirmation_token = lambda: "ragflow-abcdefghijklmnopqrstuvwxyz0123456789"
    monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod)
    api_service_mod = ModuleType("api.db.services.api_service")
    api_service_mod.APITokenService = SimpleNamespace(
        save=lambda **_kwargs: True,
        query=lambda **_kwargs: [],
        filter_update=lambda *_args, **_kwargs: True,
        filter_delete=lambda *_args, **_kwargs: True,
    )
    monkeypatch.setitem(sys.modules, "api.db.services.api_service", api_service_mod)
    kb_service_mod = ModuleType("api.db.services.knowledgebase_service")
    kb_service_mod.KnowledgebaseService = SimpleNamespace(get_by_id=lambda _kb_id: True)
    monkeypatch.setitem(sys.modules, "api.db.services.knowledgebase_service", kb_service_mod)
    user_service_mod = ModuleType("api.db.services.user_service")
    user_service_mod.UserTenantService = SimpleNamespace(
        query=lambda **_kwargs: [SimpleNamespace(role="owner", tenant_id="tenant-1")]
    )
    monkeypatch.setitem(sys.modules, "api.db.services.user_service", user_service_mod)
    db_models_mod = ModuleType("api.db.db_models")
    db_models_mod.APIToken = _DummyAPITokenModel
    monkeypatch.setitem(sys.modules, "api.db.db_models", db_models_mod)
    rag_pkg = ModuleType("rag")
    rag_pkg.__path__ = []
    monkeypatch.setitem(sys.modules, "rag", rag_pkg)
    rag_utils_pkg = ModuleType("rag.utils")
    rag_utils_pkg.__path__ = []
    monkeypatch.setitem(sys.modules, "rag.utils", rag_utils_pkg)
    redis_mod = ModuleType("rag.utils.redis_conn")
    redis_mod.REDIS_CONN = SimpleNamespace(
        health=lambda: True,
        smembers=lambda *_args, **_kwargs: set(),
        zrangebyscore=lambda *_args, **_kwargs: [],
    )
    monkeypatch.setitem(sys.modules, "rag.utils.redis_conn", redis_mod)
    health_utils_mod = ModuleType("api.utils.health_utils")
    health_utils_mod.run_health_checks = lambda: ({"status": "ok"}, True)
    health_utils_mod.get_oceanbase_status = lambda: {"status": "alive"}
    monkeypatch.setitem(sys.modules, "api.utils.health_utils", health_utils_mod)
    # jsonify becomes an identity function so route results stay plain dicts.
    quart_mod = ModuleType("quart")
    quart_mod.jsonify = lambda payload: payload
    monkeypatch.setitem(sys.modules, "quart", quart_mod)
    # Execute the real source file under a throwaway module name, injecting
    # the dummy route manager before the decorators run.
    module_path = repo_root / "api" / "apps" / "system_app.py"
    spec = importlib.util.spec_from_file_location("test_system_routes_unit_module", module_path)
    module = importlib.util.module_from_spec(spec)
    module.manager = _DummyManager()
    monkeypatch.setitem(sys.modules, "test_system_routes_unit_module", module)
    spec.loader.exec_module(module)
    return module
@pytest.mark.p2
def test_status_branch_matrix_unit(monkeypatch):
    """Cover ``status()`` with every subsystem healthy, then with every
    subsystem failing (doc engine, storage, database, Redis, heartbeats)."""
    module = _load_system_module(monkeypatch)
    # All-green scenario: every health probe succeeds and one executor heartbeat exists.
    monkeypatch.setattr(module.settings, "docStoreConn", SimpleNamespace(health=lambda: {"type": "es", "status": "green"}))
    monkeypatch.setattr(module.settings, "STORAGE_IMPL", SimpleNamespace(health=lambda: True))
    monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: True)
    monkeypatch.setattr(module.REDIS_CONN, "health", lambda: True)
    monkeypatch.setattr(module.REDIS_CONN, "smembers", lambda _key: {"executor-1"})
    monkeypatch.setattr(module.REDIS_CONN, "zrangebyscore", lambda *_args, **_kwargs: ['{"beat": 1}'])
    res = module.status()
    assert res["code"] == 0
    assert res["data"]["doc_engine"]["status"] == "green"
    assert res["data"]["storage"]["status"] == "green"
    assert res["data"]["database"]["status"] == "green"
    assert res["data"]["redis"]["status"] == "green"
    assert res["data"]["task_executor_heartbeats"]["executor-1"][0]["beat"] == 1
    # All-red scenario: each probe raises (the throwaway-generator ``throw``
    # idiom raises from inside a lambda), and Redis reports unhealthy.
    monkeypatch.setattr(
        module.settings,
        "docStoreConn",
        SimpleNamespace(health=lambda: (_ for _ in ()).throw(RuntimeError("doc down"))),
    )
    monkeypatch.setattr(
        module.settings,
        "STORAGE_IMPL",
        SimpleNamespace(health=lambda: (_ for _ in ()).throw(RuntimeError("storage down"))),
    )
    monkeypatch.setattr(
        module.KnowledgebaseService,
        "get_by_id",
        lambda _kb_id: (_ for _ in ()).throw(RuntimeError("db down")),
    )
    monkeypatch.setattr(module.REDIS_CONN, "health", lambda: False)
    monkeypatch.setattr(module.REDIS_CONN, "smembers", lambda _key: (_ for _ in ()).throw(RuntimeError("hb down")))
    res = module.status()
    # The route still returns code 0; per-subsystem statuses carry the errors.
    assert res["code"] == 0
    assert res["data"]["doc_engine"]["status"] == "red"
    assert "doc down" in res["data"]["doc_engine"]["error"]
    assert res["data"]["storage"]["status"] == "red"
    assert "storage down" in res["data"]["storage"]["error"]
    assert res["data"]["database"]["status"] == "red"
    assert "db down" in res["data"]["database"]["error"]
    assert res["data"]["redis"]["status"] == "red"
    assert "Lost connection!" in res["data"]["redis"]["error"]
    assert res["data"]["task_executor_heartbeats"] == {}
@pytest.mark.p2
def test_healthz_and_oceanbase_status_matrix_unit(monkeypatch):
    """Cover ``healthz()`` (200 when healthy, 500 when degraded) and
    ``oceanbase_status()`` (success and exception branches)."""
    module = _load_system_module(monkeypatch)
    # healthz: all checks pass -> HTTP 200.
    monkeypatch.setattr(module, "run_health_checks", lambda: ({"status": "ok"}, True))
    payload, status = module.healthz()
    assert status == 200
    assert payload["status"] == "ok"
    # healthz: any check failing -> HTTP 500 with the degraded payload.
    monkeypatch.setattr(module, "run_health_checks", lambda: ({"status": "degraded"}, False))
    payload, status = module.healthz()
    assert status == 500
    assert payload["status"] == "degraded"
    # oceanbase_status: success path passes the probe result through.
    monkeypatch.setattr(module, "get_oceanbase_status", lambda: {"status": "alive", "latency_ms": 8})
    res = module.oceanbase_status()
    assert res["code"] == 0
    assert res["data"]["status"] == "alive"
    # oceanbase_status: an exception becomes code 500 with an error payload.
    monkeypatch.setattr(module, "get_oceanbase_status", lambda: (_ for _ in ()).throw(RuntimeError("ocean boom")))
    res = module.oceanbase_status()
    assert res["code"] == 500
    assert res["data"]["status"] == "error"
    assert "ocean boom" in res["data"]["message"]
@pytest.mark.p2
def test_system_token_routes_matrix_unit(monkeypatch):
    """Cover ``new_token``, ``token_list``, and ``rm`` for API tokens across
    tenant-missing, failure, success, and exception branches.

    NOTE: stubs are installed sequentially and persist across calls, so the
    scenario order matters.
    """
    module = _load_system_module(monkeypatch)
    # new_token: no tenant membership for the current user.
    monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [])
    res = module.new_token()
    assert res["message"] == "Tenant not found!"
    # new_token: token persistence failure.
    monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [SimpleNamespace(role="owner", tenant_id="tenant-1")])
    monkeypatch.setattr(module.APITokenService, "save", lambda **_kwargs: False)
    res = module.new_token()
    assert res["message"] == "Fail to new a dialog!"
    # new_token: tenant lookup raising is routed to server_error_response.
    monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: (_ for _ in ()).throw(RuntimeError("tenant query boom")))
    res = module.new_token()
    assert res["code"] == 100
    assert "tenant query boom" in res["message"]
    # token_list: no tenant membership.
    monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [])
    res = module.token_list()
    assert res["message"] == "Tenant not found!"

    class _Token:
        # Minimal token record with the to_dict() shape the route serializes.
        def __init__(self, token, beta):
            self.token = token
            self.beta = beta

        def to_dict(self):
            return {"token": self.token, "beta": self.beta}

    filter_updates = []
    # token_list: a token with an empty beta gets one generated (32 chars) and
    # persisted via filter_update; a populated beta is returned unchanged.
    monkeypatch.setattr(module, "generate_confirmation_token", lambda: "ragflow-abcdefghijklmnopqrstuvwxyz0123456789")
    monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [SimpleNamespace(role="owner", tenant_id="tenant-9")])
    monkeypatch.setattr(module.APITokenService, "query", lambda **_kwargs: [_Token("tok-1", ""), _Token("tok-2", "beta-2")])
    monkeypatch.setattr(module.APITokenService, "filter_update", lambda conds, payload: filter_updates.append((conds, payload)))
    res = module.token_list()
    assert res["code"] == 0
    assert len(res["data"]) == 2
    assert len(res["data"][0]["beta"]) == 32
    assert res["data"][1]["beta"] == "beta-2"
    assert len(filter_updates) == 1
    # token_list: token query raising is routed to server_error_response.
    monkeypatch.setattr(
        module.APITokenService,
        "query",
        lambda **_kwargs: (_ for _ in ()).throw(RuntimeError("token list boom")),
    )
    res = module.token_list()
    assert res["code"] == 100
    assert "token list boom" in res["message"]
    # rm: no tenant membership.
    monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [])
    res = module.rm("tok-1")
    assert res["message"] == "Tenant not found!"
    # rm: success path calls filter_delete and returns True.
    deleted = []
    monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [SimpleNamespace(role="owner", tenant_id="tenant-3")])
    monkeypatch.setattr(module.APITokenService, "filter_delete", lambda conds: deleted.append(conds))
    res = module.rm("tok-1")
    assert res["code"] == 0
    assert res["data"] is True
    assert deleted
    # rm: deletion raising is routed to server_error_response.
    monkeypatch.setattr(
        module.APITokenService,
        "filter_delete",
        lambda _conds: (_ for _ in ()).throw(RuntimeError("delete boom")),
    )
    res = module.rm("tok-1")
    assert res["code"] == 100
    assert "delete boom" in res["message"]
@pytest.mark.p2
def test_get_config_returns_register_enabled_unit(monkeypatch):
    """``get_config()`` must reflect the REGISTER_ENABLED setting in its
    ``registerEnabled`` field."""
    module = _load_system_module(monkeypatch)
    monkeypatch.setattr(module.settings, "REGISTER_ENABLED", False)
    res = module.get_config()
    assert res["code"] == 0
    assert res["data"]["registerEnabled"] is False

View File

@ -0,0 +1,318 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import importlib.util
import sys
from pathlib import Path
from types import ModuleType, SimpleNamespace
import pytest
class _DummyManager:
def route(self, *_args, **_kwargs):
def decorator(func):
return func
return decorator
class _AwaitableValue:
def __init__(self, value):
self._value = value
def __await__(self):
async def _co():
return self._value
return _co().__await__()
class _ExprField:
def __init__(self, name):
self.name = name
def __eq__(self, other):
return (self.name, other)
class _Invitee:
def __init__(self, user_id="invitee-1", email="invitee@example.com"):
self.id = user_id
self.email = email
def to_dict(self):
return {
"id": self.id,
"avatar": "avatar-url",
"email": self.email,
"nickname": "Invitee",
"password": "ignored",
}
def _run(coro):
return asyncio.run(coro)
def _set_request_json(monkeypatch, module, payload):
    """Patch *module*'s ``get_request_json`` so awaiting it yields *payload*."""

    def _fake_get_request_json():
        return _AwaitableValue(payload)

    monkeypatch.setattr(module, "get_request_json", _fake_get_request_json)
def _load_tenant_module(monkeypatch):
    """Load ``api/apps/tenant_app.py`` in isolation with every import stubbed.

    Dependencies are replaced with in-memory stand-ins via ``sys.modules``
    *before* the module file is executed, so the tenant routes can be called
    directly without a running app or database. Installation order matters:
    each stub must exist before ``exec_module`` triggers the import.
    """
    repo_root = Path(__file__).resolve().parents[4]
    # Stub the "api" package hierarchy (real __path__ kept so submodules resolve).
    api_pkg = ModuleType("api")
    api_pkg.__path__ = [str(repo_root / "api")]
    monkeypatch.setitem(sys.modules, "api", api_pkg)
    apps_mod = ModuleType("api.apps")
    apps_mod.__path__ = [str(repo_root / "api" / "apps")]
    # current_user id matches the default tenant so auth checks pass by default;
    # login_required becomes a pass-through so routes are directly callable.
    apps_mod.current_user = SimpleNamespace(id="tenant-1", email="owner@example.com")
    apps_mod.login_required = lambda fn: fn
    monkeypatch.setitem(sys.modules, "api.apps", apps_mod)
    db_mod = ModuleType("api.db")
    db_mod.UserTenantRole = SimpleNamespace(NORMAL="normal", OWNER="owner", INVITE="invite")
    monkeypatch.setitem(sys.modules, "api.db", db_mod)
    # UserTenant model stub built dynamically with expression-style columns.
    db_models_mod = ModuleType("api.db.db_models")
    db_models_mod.UserTenant = type(
        "UserTenant",
        (),
        {
            "tenant_id": _ExprField("tenant_id"),
            "user_id": _ExprField("user_id"),
        },
    )
    monkeypatch.setitem(sys.modules, "api.db.db_models", db_models_mod)
    services_pkg = ModuleType("api.db.services")
    services_pkg.__path__ = []
    monkeypatch.setitem(sys.modules, "api.db.services", services_pkg)
    user_service_mod = ModuleType("api.db.services.user_service")

    class _UserTenantService:
        # Default "empty database" behavior; tests monkeypatch individual methods.
        @staticmethod
        def get_by_tenant_id(_tenant_id):
            return []

        @staticmethod
        def query(**_kwargs):
            return []

        @staticmethod
        def save(**_kwargs):
            return True

        @staticmethod
        def filter_delete(_conditions):
            return True

        @staticmethod
        def get_tenants_by_user_id(_user_id):
            return []

        @staticmethod
        def filter_update(_conditions, _payload):
            return True

    class _UserService:
        # Default "user not found" behavior; tests monkeypatch as needed.
        @staticmethod
        def query(**_kwargs):
            return []

        @staticmethod
        def get_by_id(_user_id):
            return False, None

    user_service_mod.UserTenantService = _UserTenantService
    user_service_mod.UserService = _UserService
    monkeypatch.setitem(sys.modules, "api.db.services.user_service", user_service_mod)
    # Response helpers mirror the real api_utils envelope shape.
    api_utils_mod = ModuleType("api.utils.api_utils")
    api_utils_mod.get_json_result = lambda data=None, message="", code=0: {"code": code, "message": message, "data": data}
    api_utils_mod.get_data_error_result = lambda message="": {"code": 102, "message": message, "data": False}
    api_utils_mod.server_error_response = lambda exc: {"code": 100, "message": repr(exc), "data": False}
    api_utils_mod.validate_request = lambda *_args, **_kwargs: (lambda fn: fn)
    api_utils_mod.get_request_json = lambda: _AwaitableValue({})
    monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod)
    web_utils_mod = ModuleType("api.utils.web_utils")
    web_utils_mod.send_invite_email = lambda **_kwargs: {"ok": True}
    monkeypatch.setitem(sys.modules, "api.utils.web_utils", web_utils_mod)
    common_pkg = ModuleType("common")
    common_pkg.__path__ = [str(repo_root / "common")]
    monkeypatch.setitem(sys.modules, "common", common_pkg)
    constants_mod = ModuleType("common.constants")
    constants_mod.RetCode = SimpleNamespace(AUTHENTICATION_ERROR=401, SERVER_ERROR=500, DATA_ERROR=102)
    constants_mod.StatusEnum = SimpleNamespace(VALID=SimpleNamespace(value=1))
    monkeypatch.setitem(sys.modules, "common.constants", constants_mod)
    misc_utils_mod = ModuleType("common.misc_utils")
    misc_utils_mod.get_uuid = lambda: "uuid-1"
    monkeypatch.setitem(sys.modules, "common.misc_utils", misc_utils_mod)
    time_utils_mod = ModuleType("common.time_utils")
    time_utils_mod.delta_seconds = lambda _value: 0
    monkeypatch.setitem(sys.modules, "common.time_utils", time_utils_mod)
    settings_mod = ModuleType("common.settings")
    settings_mod.MAIL_FRONTEND_URL = "https://frontend.example/invite"
    monkeypatch.setitem(sys.modules, "common.settings", settings_mod)
    common_pkg.settings = settings_mod
    # Drop any stale copy so each call executes the source file afresh.
    sys.modules.pop("test_tenant_app_unit_module", None)
    # Execute the real source file under a throwaway module name, injecting
    # the dummy route manager before the decorators run.
    module_path = repo_root / "api" / "apps" / "tenant_app.py"
    spec = importlib.util.spec_from_file_location("test_tenant_app_unit_module", module_path)
    module = importlib.util.module_from_spec(spec)
    module.manager = _DummyManager()
    monkeypatch.setitem(sys.modules, "test_tenant_app_unit_module", module)
    spec.loader.exec_module(module)
    return module
@pytest.mark.p2
def test_user_list_auth_success_exception_matrix_unit(monkeypatch):
    """Cover ``user_list``: auth rejection, success with delta_seconds
    enrichment, and the exception branch."""
    module = _load_tenant_module(monkeypatch)
    # Auth rejection: current user id differs from the requested tenant id.
    module.current_user.id = "other-user"
    res = module.user_list("tenant-1")
    assert res["code"] == module.RetCode.AUTHENTICATION_ERROR, res
    assert res["message"] == "No authorization.", res
    # Success path: each returned user gains a delta_seconds field.
    module.current_user.id = "tenant-1"
    monkeypatch.setattr(
        module.UserTenantService,
        "get_by_tenant_id",
        lambda _tenant_id: [{"id": "u1", "update_date": "2024-01-01 00:00:00"}],
    )
    monkeypatch.setattr(module, "delta_seconds", lambda _value: 42)
    res = module.user_list("tenant-1")
    assert res["code"] == 0, res
    assert res["data"][0]["delta_seconds"] == 42, res
    # Exception branch: a raising lookup is routed to server_error_response.
    monkeypatch.setattr(module.UserTenantService, "get_by_tenant_id", lambda _tenant_id: (_ for _ in ()).throw(RuntimeError("list boom")))
    res = module.user_list("tenant-1")
    assert res["code"] == 100, res
    assert "list boom" in res["message"], res
@pytest.mark.p2
def test_create_invite_role_and_email_failure_matrix_unit(monkeypatch):
    """Cover the invite ``create`` route: auth rejection, missing user,
    role conflicts, the full success path, and email-send failure.

    NOTE: stubs are installed sequentially and persist across calls, so the
    scenario order matters.
    """
    module = _load_tenant_module(monkeypatch)
    # Auth rejection: caller is not the tenant owner.
    module.current_user.id = "other-user"
    _set_request_json(monkeypatch, module, {"email": "invitee@example.com"})
    res = _run(module.create("tenant-1"))
    assert res["code"] == module.RetCode.AUTHENTICATION_ERROR, res
    assert res["message"] == "No authorization.", res
    # Invited email does not match any user.
    module.current_user.id = "tenant-1"
    monkeypatch.setattr(module.UserService, "query", lambda **_kwargs: [])
    res = _run(module.create("tenant-1"))
    assert res["message"] == "User not found.", res
    # Role conflicts: already a member, already the owner, unknown role.
    invitee = _Invitee()
    monkeypatch.setattr(module.UserService, "query", lambda **_kwargs: [invitee])
    monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [SimpleNamespace(role=module.UserTenantRole.NORMAL)])
    res = _run(module.create("tenant-1"))
    assert "already in the team." in res["message"], res
    monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [SimpleNamespace(role=module.UserTenantRole.OWNER)])
    res = _run(module.create("tenant-1"))
    assert "owner of the team." in res["message"], res
    monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [SimpleNamespace(role="strange-role")])
    res = _run(module.create("tenant-1"))
    assert "role: strange-role is invalid." in res["message"], res
    # Success path: membership saved with INVITE role, email send scheduled
    # via asyncio.create_task, and sensitive fields stripped from the response.
    saved = []
    scheduled = []
    monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [])
    monkeypatch.setattr(module.UserTenantService, "save", lambda **kwargs: saved.append(kwargs) or True)
    monkeypatch.setattr(module.UserService, "get_by_id", lambda _user_id: (True, SimpleNamespace(nickname="Inviter Nick")))
    monkeypatch.setattr(module, "send_invite_email", lambda **kwargs: kwargs)
    monkeypatch.setattr(module.asyncio, "create_task", lambda payload: scheduled.append(payload) or SimpleNamespace())
    res = _run(module.create("tenant-1"))
    assert res["code"] == 0, res
    assert saved and saved[-1]["role"] == module.UserTenantRole.INVITE, saved
    assert scheduled and scheduled[-1]["inviter"] == "Inviter Nick", scheduled
    # "password" must not leak into the response payload.
    assert sorted(res["data"].keys()) == ["avatar", "email", "id", "nickname"], res
    # Email scheduling failure maps to SERVER_ERROR.
    monkeypatch.setattr(module.asyncio, "create_task", lambda _payload: (_ for _ in ()).throw(RuntimeError("send boom")))
    res = _run(module.create("tenant-1"))
    assert res["code"] == module.RetCode.SERVER_ERROR, res
    assert "Failed to send invite email." in res["message"], res
@pytest.mark.p2
def test_rm_and_tenant_list_matrix_unit(monkeypatch):
    """Cover ``rm`` (auth rejection, success, exception) and ``tenant_list``
    (success with delta_seconds enrichment, exception)."""
    module = _load_tenant_module(monkeypatch)
    # rm: auth rejection when the caller is neither the tenant nor the target user.
    module.current_user.id = "outsider"
    res = module.rm("tenant-1", "user-2")
    assert res["code"] == module.RetCode.AUTHENTICATION_ERROR, res
    assert res["message"] == "No authorization.", res
    # rm: success path must invoke filter_delete.
    module.current_user.id = "tenant-1"
    deleted = []
    monkeypatch.setattr(module.UserTenantService, "filter_delete", lambda conditions: deleted.append(conditions) or True)
    res = module.rm("tenant-1", "user-2")
    assert res["code"] == 0, res
    assert res["data"] is True, res
    assert deleted, "filter_delete should be called"
    # rm: a raising delete is routed to server_error_response.
    monkeypatch.setattr(module.UserTenantService, "filter_delete", lambda _conditions: (_ for _ in ()).throw(RuntimeError("rm boom")))
    res = module.rm("tenant-1", "user-2")
    assert res["code"] == 100, res
    assert "rm boom" in res["message"], res
    # tenant_list: each tenant row gains a delta_seconds field.
    monkeypatch.setattr(
        module.UserTenantService,
        "get_tenants_by_user_id",
        lambda _user_id: [{"id": "tenant-1", "update_date": "2024-01-01 00:00:00"}],
    )
    monkeypatch.setattr(module, "delta_seconds", lambda _value: 9)
    res = module.tenant_list()
    assert res["code"] == 0, res
    assert res["data"][0]["delta_seconds"] == 9, res
    # tenant_list: a raising lookup is routed to server_error_response.
    monkeypatch.setattr(module.UserTenantService, "get_tenants_by_user_id", lambda _user_id: (_ for _ in ()).throw(RuntimeError("tenant boom")))
    res = module.tenant_list()
    assert res["code"] == 100, res
    assert "tenant boom" in res["message"], res
@pytest.mark.p2
def test_agree_success_and_exception_unit(monkeypatch):
    """Cover ``agree``: accepting an invite flips the membership role to
    NORMAL via filter_update; a raising update hits the exception branch."""
    module = _load_tenant_module(monkeypatch)
    calls = []
    monkeypatch.setattr(module.UserTenantService, "filter_update", lambda conditions, payload: calls.append((conditions, payload)) or True)
    res = module.agree("tenant-1")
    assert res["code"] == 0, res
    assert res["data"] is True, res
    # The update payload must promote the membership to the NORMAL role.
    assert calls and calls[-1][1]["role"] == module.UserTenantRole.NORMAL
    # A raising update is routed to server_error_response.
    monkeypatch.setattr(module.UserTenantService, "filter_update", lambda _conditions, _payload: (_ for _ in ()).throw(RuntimeError("agree boom")))
    res = module.agree("tenant-1")
    assert res["code"] == 100, res
    assert "agree boom" in res["message"], res

File diff suppressed because it is too large Load Diff