mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-05-03 16:57:48 +08:00
tests: improve RAGFlow coverage based on Codecov report (#13219)
### What problem does this PR solve? Codecov’s coverage report shows that several RAGFlow code paths are currently untested or under-tested. This makes it easier for regressions to slip in during refactors and feature work. This PR adds targeted automated tests to cover the files and branches highlighted by Codecov, improving confidence in core behavior while keeping runtime functionality unchanged. ### Type of change - [x] Other (please describe): Test coverage improvement (adds/extends unit and integration tests to address Codecov-reported gaps)
This commit is contained in:
@ -125,7 +125,8 @@ class TestDatasetUpdate:
|
||||
@pytest.mark.p1
|
||||
@given(name=valid_names())
|
||||
@example("a" * 128)
|
||||
@settings(max_examples=20, suppress_health_check=[HealthCheck.function_scoped_fixture])
|
||||
# Network-bound API call; disable Hypothesis deadline to avoid flaky timeouts.
|
||||
@settings(max_examples=20, suppress_health_check=[HealthCheck.function_scoped_fixture], deadline=None)
|
||||
def test_name(self, HttpApiAuth, add_dataset_func, name):
|
||||
dataset_id = add_dataset_func
|
||||
payload = {"name": name}
|
||||
|
||||
@ -77,6 +77,15 @@ class _StubResponse:
|
||||
self.headers = _StubHeaders()
|
||||
|
||||
|
||||
class _DummyUploadFile:
|
||||
def __init__(self, filename):
|
||||
self.filename = filename
|
||||
self.saved_path = None
|
||||
|
||||
async def save(self, path):
|
||||
self.saved_path = path
|
||||
|
||||
|
||||
def _run(coro):
|
||||
return asyncio.run(coro)
|
||||
|
||||
@ -96,6 +105,16 @@ async def _collect_stream(body):
|
||||
return items
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def auth():
|
||||
return "unit-auth"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def set_tenant_info():
|
||||
return None
|
||||
|
||||
|
||||
def _load_session_module(monkeypatch):
|
||||
repo_root = Path(__file__).resolve().parents[4]
|
||||
common_pkg = ModuleType("common")
|
||||
@ -1125,3 +1144,355 @@ def test_searchbots_retrieval_test_embedded_matrix_unit(monkeypatch):
|
||||
assert retrieval_capture["rank_feature"] == ["label-1"]
|
||||
assert retrieval_capture["rerank_mdl"] is not None
|
||||
assert any(call[1] == module.LLMType.EMBEDDING.value and call[3].get("llm_name") == "embd-model" for call in llm_calls)
|
||||
|
||||
llm_calls.clear()
|
||||
|
||||
async def _fake_keyword_extraction(_chat_mdl, question):
|
||||
return f"-{question}-keywords"
|
||||
|
||||
async def _fake_kg_retrieval(question, tenant_ids, kb_ids, _embd_mdl, _chat_mdl):
|
||||
return {
|
||||
"id": "kg-chunk",
|
||||
"question": question,
|
||||
"tenant_ids": tenant_ids,
|
||||
"kb_ids": kb_ids,
|
||||
"content_with_weight": 1,
|
||||
"vector": [0.5],
|
||||
}
|
||||
|
||||
monkeypatch.setattr(module, "keyword_extraction", _fake_keyword_extraction)
|
||||
monkeypatch.setattr(module.settings, "kg_retriever", SimpleNamespace(retrieval=_fake_kg_retrieval))
|
||||
monkeypatch.setattr(
|
||||
module,
|
||||
"get_request_json",
|
||||
lambda: _AwaitableValue(
|
||||
{
|
||||
"kb_id": "kb-1",
|
||||
"question": "keyword-q",
|
||||
"rerank_id": "manual-reranker",
|
||||
"keyword": True,
|
||||
"use_kg": True,
|
||||
}
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
module.KnowledgebaseService,
|
||||
"get_by_id",
|
||||
lambda _kb_id: (True, SimpleNamespace(tenant_id="tenant-kb", embd_id="embd-model")),
|
||||
)
|
||||
res = _run(handler())
|
||||
assert res["code"] == 0
|
||||
assert res["data"]["chunks"][0]["id"] == "kg-chunk"
|
||||
assert all("vector" not in chunk for chunk in res["data"]["chunks"])
|
||||
assert any(call[1] == module.LLMType.RERANK.value for call in llm_calls)
|
||||
|
||||
async def _raise_not_found(*_args, **_kwargs):
|
||||
raise RuntimeError("x not_found y")
|
||||
|
||||
monkeypatch.setattr(module.settings, "retriever", SimpleNamespace(retrieval=_raise_not_found))
|
||||
monkeypatch.setattr(
|
||||
module,
|
||||
"get_request_json",
|
||||
lambda: _AwaitableValue({"kb_id": "kb-1", "question": "q"}),
|
||||
)
|
||||
res = _run(handler())
|
||||
assert res["message"] == "No chunk found! Check the chunk status please!"
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_searchbots_related_questions_embedded_matrix_unit(monkeypatch):
|
||||
module = _load_session_module(monkeypatch)
|
||||
handler = inspect.unwrap(module.related_questions_embedded)
|
||||
|
||||
monkeypatch.setattr(module, "request", SimpleNamespace(headers={"Authorization": "Bearer"}))
|
||||
res = _run(handler())
|
||||
assert res["message"] == "Authorization is not valid!"
|
||||
|
||||
monkeypatch.setattr(module, "request", SimpleNamespace(headers={"Authorization": "Bearer bad"}))
|
||||
monkeypatch.setattr(module.APIToken, "query", lambda **_kwargs: [])
|
||||
res = _run(handler())
|
||||
assert "API key is invalid" in res["message"]
|
||||
|
||||
monkeypatch.setattr(module, "request", SimpleNamespace(headers={"Authorization": "Bearer ok"}))
|
||||
monkeypatch.setattr(module.APIToken, "query", lambda **_kwargs: [SimpleNamespace(tenant_id="")])
|
||||
monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"question": "q"}))
|
||||
res = _run(handler())
|
||||
assert res["message"] == "permission denined."
|
||||
|
||||
captured = {}
|
||||
|
||||
class _FakeChatBundle:
|
||||
async def async_chat(self, prompt, messages, options):
|
||||
captured["prompt"] = prompt
|
||||
captured["messages"] = messages
|
||||
captured["options"] = options
|
||||
return "1. Alpha\n2. Beta\nignored"
|
||||
|
||||
def _fake_bundle(*args, **_kwargs):
|
||||
captured["bundle_args"] = args
|
||||
return _FakeChatBundle()
|
||||
|
||||
monkeypatch.setattr(module.APIToken, "query", lambda **_kwargs: [SimpleNamespace(tenant_id="tenant-1")])
|
||||
monkeypatch.setattr(
|
||||
module,
|
||||
"get_request_json",
|
||||
lambda: _AwaitableValue({"question": "solar", "search_id": "search-1"}),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
module.SearchService,
|
||||
"get_detail",
|
||||
lambda _search_id: {"search_config": {"chat_id": "chat-x", "llm_setting": {"temperature": 0.2}}},
|
||||
)
|
||||
monkeypatch.setattr(module, "LLMBundle", _fake_bundle)
|
||||
res = _run(handler())
|
||||
assert res["code"] == 0
|
||||
assert res["data"] == ["Alpha", "Beta"]
|
||||
assert captured["bundle_args"] == ("tenant-1", module.LLMType.CHAT, "chat-x")
|
||||
assert captured["options"] == {"temperature": 0.2}
|
||||
assert "Keywords: solar" in captured["messages"][0]["content"]
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_searchbots_detail_share_embedded_matrix_unit(monkeypatch):
|
||||
module = _load_session_module(monkeypatch)
|
||||
handler = inspect.unwrap(module.detail_share_embedded)
|
||||
|
||||
monkeypatch.setattr(module, "request", SimpleNamespace(headers={"Authorization": "Bearer"}, args={"search_id": "s-1"}))
|
||||
res = _run(handler())
|
||||
assert res["message"] == "Authorization is not valid!"
|
||||
|
||||
monkeypatch.setattr(module, "request", SimpleNamespace(headers={"Authorization": "Bearer bad"}, args={"search_id": "s-1"}))
|
||||
monkeypatch.setattr(module.APIToken, "query", lambda **_kwargs: [])
|
||||
res = _run(handler())
|
||||
assert "API key is invalid" in res["message"]
|
||||
|
||||
monkeypatch.setattr(module, "request", SimpleNamespace(headers={"Authorization": "Bearer ok"}, args={"search_id": "s-1"}))
|
||||
monkeypatch.setattr(module.APIToken, "query", lambda **_kwargs: [SimpleNamespace(tenant_id="")])
|
||||
res = _run(handler())
|
||||
assert res["message"] == "permission denined."
|
||||
|
||||
monkeypatch.setattr(module.APIToken, "query", lambda **_kwargs: [SimpleNamespace(tenant_id="tenant-1")])
|
||||
monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [SimpleNamespace(tenant_id="tenant-a")])
|
||||
monkeypatch.setattr(module.SearchService, "query", lambda **_kwargs: [])
|
||||
res = _run(handler())
|
||||
assert res["code"] == module.RetCode.OPERATING_ERROR
|
||||
assert "Has no permission for this operation." in res["message"]
|
||||
|
||||
monkeypatch.setattr(module.SearchService, "query", lambda **_kwargs: [SimpleNamespace(id="s-1")])
|
||||
monkeypatch.setattr(module.SearchService, "get_detail", lambda _sid: None)
|
||||
res = _run(handler())
|
||||
assert res["message"] == "Can't find this Search App!"
|
||||
|
||||
monkeypatch.setattr(module.SearchService, "get_detail", lambda _sid: {"id": "s-1", "name": "search-app"})
|
||||
res = _run(handler())
|
||||
assert res["code"] == 0
|
||||
assert res["data"]["id"] == "s-1"
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_searchbots_mindmap_embedded_matrix_unit(monkeypatch):
|
||||
module = _load_session_module(monkeypatch)
|
||||
handler = inspect.unwrap(module.mindmap)
|
||||
|
||||
monkeypatch.setattr(module, "request", SimpleNamespace(headers={"Authorization": "Bearer"}))
|
||||
res = _run(handler())
|
||||
assert res["message"] == "Authorization is not valid!"
|
||||
|
||||
monkeypatch.setattr(module, "request", SimpleNamespace(headers={"Authorization": "Bearer bad"}))
|
||||
monkeypatch.setattr(module.APIToken, "query", lambda **_kwargs: [])
|
||||
res = _run(handler())
|
||||
assert "API key is invalid" in res["message"]
|
||||
|
||||
monkeypatch.setattr(module, "request", SimpleNamespace(headers={"Authorization": "Bearer ok"}))
|
||||
monkeypatch.setattr(module.APIToken, "query", lambda **_kwargs: [SimpleNamespace(tenant_id="tenant-1")])
|
||||
monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"question": "q", "kb_ids": ["kb-1"]}))
|
||||
|
||||
captured = {}
|
||||
|
||||
async def _gen_ok(question, kb_ids, tenant_id, search_config):
|
||||
captured["params"] = (question, kb_ids, tenant_id, search_config)
|
||||
return {"nodes": [question]}
|
||||
|
||||
monkeypatch.setattr(module, "gen_mindmap", _gen_ok)
|
||||
res = _run(handler())
|
||||
assert res["code"] == 0
|
||||
assert res["data"] == {"nodes": ["q"]}
|
||||
assert captured["params"] == ("q", ["kb-1"], "tenant-1", {})
|
||||
|
||||
monkeypatch.setattr(
|
||||
module,
|
||||
"get_request_json",
|
||||
lambda: _AwaitableValue({"question": "q2", "kb_ids": ["kb-1"], "search_id": "search-1"}),
|
||||
)
|
||||
monkeypatch.setattr(module.SearchService, "get_detail", lambda _sid: {"search_config": {"mode": "graph"}})
|
||||
res = _run(handler())
|
||||
assert res["code"] == 0
|
||||
assert captured["params"] == ("q2", ["kb-1"], "tenant-1", {"mode": "graph"})
|
||||
|
||||
async def _gen_error(*_args, **_kwargs):
|
||||
return {"error": "mindmap boom"}
|
||||
|
||||
monkeypatch.setattr(module, "gen_mindmap", _gen_error)
|
||||
res = _run(handler())
|
||||
assert "mindmap boom" in res["message"]
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_sequence2txt_embedded_validation_and_stream_matrix_unit(monkeypatch):
|
||||
module = _load_session_module(monkeypatch)
|
||||
handler = inspect.unwrap(module.sequence2txt)
|
||||
monkeypatch.setattr(module, "Response", _StubResponse)
|
||||
monkeypatch.setattr(module.tempfile, "mkstemp", lambda suffix: (11, f"/tmp/audio{suffix}"))
|
||||
monkeypatch.setattr(module.os, "close", lambda _fd: None)
|
||||
|
||||
def _set_request(form, files):
|
||||
monkeypatch.setattr(
|
||||
module,
|
||||
"request",
|
||||
SimpleNamespace(form=_AwaitableValue(form), files=_AwaitableValue(files)),
|
||||
)
|
||||
|
||||
_set_request({"stream": "false"}, {})
|
||||
res = _run(handler("tenant-1"))
|
||||
assert "Missing 'file' in multipart form-data" in res["message"]
|
||||
|
||||
_set_request({"stream": "false"}, {"file": _DummyUploadFile("bad.txt")})
|
||||
res = _run(handler("tenant-1"))
|
||||
assert "Unsupported audio format: .txt" in res["message"]
|
||||
|
||||
_set_request({"stream": "false"}, {"file": _DummyUploadFile("audio.wav")})
|
||||
monkeypatch.setattr(module.TenantService, "get_info_by", lambda _tid: [])
|
||||
res = _run(handler("tenant-1"))
|
||||
assert res["message"] == "Tenant not found!"
|
||||
|
||||
_set_request({"stream": "false"}, {"file": _DummyUploadFile("audio.wav")})
|
||||
monkeypatch.setattr(module.TenantService, "get_info_by", lambda _tid: [{"tenant_id": "tenant-1", "asr_id": ""}])
|
||||
res = _run(handler("tenant-1"))
|
||||
assert res["message"] == "No default ASR model is set"
|
||||
|
||||
class _SyncASR:
|
||||
def transcription(self, _path):
|
||||
return "transcribed text"
|
||||
|
||||
def stream_transcription(self, _path):
|
||||
return []
|
||||
|
||||
_set_request({"stream": "false"}, {"file": _DummyUploadFile("audio.wav")})
|
||||
monkeypatch.setattr(module.TenantService, "get_info_by", lambda _tid: [{"tenant_id": "tenant-1", "asr_id": "asr-x"}])
|
||||
monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: _SyncASR())
|
||||
monkeypatch.setattr(module.os, "remove", lambda _path: (_ for _ in ()).throw(RuntimeError("cleanup fail")))
|
||||
res = _run(handler("tenant-1"))
|
||||
assert res["code"] == 0
|
||||
assert res["data"]["text"] == "transcribed text"
|
||||
|
||||
class _StreamASR:
|
||||
def transcription(self, _path):
|
||||
return ""
|
||||
|
||||
def stream_transcription(self, _path):
|
||||
yield {"event": "partial", "text": "hello"}
|
||||
|
||||
_set_request({"stream": "true"}, {"file": _DummyUploadFile("audio.wav")})
|
||||
monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: _StreamASR())
|
||||
monkeypatch.setattr(module.os, "remove", lambda _path: None)
|
||||
resp = _run(handler("tenant-1"))
|
||||
assert isinstance(resp, _StubResponse)
|
||||
assert resp.content_type == "text/event-stream"
|
||||
chunks = _run(_collect_stream(resp.body))
|
||||
assert any('"event": "partial"' in chunk for chunk in chunks)
|
||||
|
||||
class _ErrorASR:
|
||||
def transcription(self, _path):
|
||||
return ""
|
||||
|
||||
def stream_transcription(self, _path):
|
||||
raise RuntimeError("stream asr boom")
|
||||
|
||||
_set_request({"stream": "true"}, {"file": _DummyUploadFile("audio.wav")})
|
||||
monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: _ErrorASR())
|
||||
monkeypatch.setattr(module.os, "remove", lambda _path: (_ for _ in ()).throw(RuntimeError("cleanup boom")))
|
||||
resp = _run(handler("tenant-1"))
|
||||
chunks = _run(_collect_stream(resp.body))
|
||||
assert any("stream asr boom" in chunk for chunk in chunks)
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_tts_embedded_stream_and_error_matrix_unit(monkeypatch):
|
||||
module = _load_session_module(monkeypatch)
|
||||
handler = inspect.unwrap(module.tts)
|
||||
monkeypatch.setattr(module, "Response", _StubResponse)
|
||||
monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"text": "A。B"}))
|
||||
|
||||
monkeypatch.setattr(module.TenantService, "get_info_by", lambda _tid: [])
|
||||
res = _run(handler("tenant-1"))
|
||||
assert res["message"] == "Tenant not found!"
|
||||
|
||||
monkeypatch.setattr(module.TenantService, "get_info_by", lambda _tid: [{"tenant_id": "tenant-1", "tts_id": ""}])
|
||||
res = _run(handler("tenant-1"))
|
||||
assert res["message"] == "No default TTS model is set"
|
||||
|
||||
class _TTSOk:
|
||||
def tts(self, txt):
|
||||
if not txt:
|
||||
return []
|
||||
yield f"chunk-{txt}".encode("utf-8")
|
||||
|
||||
monkeypatch.setattr(module.TenantService, "get_info_by", lambda _tid: [{"tenant_id": "tenant-1", "tts_id": "tts-x"}])
|
||||
monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: _TTSOk())
|
||||
resp = _run(handler("tenant-1"))
|
||||
assert resp.mimetype == "audio/mpeg"
|
||||
assert resp.headers.get("Cache-Control") == "no-cache"
|
||||
assert resp.headers.get("Connection") == "keep-alive"
|
||||
assert resp.headers.get("X-Accel-Buffering") == "no"
|
||||
chunks = _run(_collect_stream(resp.body))
|
||||
assert any("chunk-A" in chunk for chunk in chunks)
|
||||
assert any("chunk-B" in chunk for chunk in chunks)
|
||||
|
||||
class _TTSErr:
|
||||
def tts(self, _txt):
|
||||
raise RuntimeError("tts boom")
|
||||
|
||||
monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: _TTSErr())
|
||||
resp = _run(handler("tenant-1"))
|
||||
chunks = _run(_collect_stream(resp.body))
|
||||
assert any('"code": 500' in chunk and "**ERROR**: tts boom" in chunk for chunk in chunks)
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_build_reference_chunks_metadata_matrix_unit(monkeypatch):
|
||||
module = _load_session_module(monkeypatch)
|
||||
|
||||
monkeypatch.setattr(module, "chunks_format", lambda _reference: [{"dataset_id": "kb-1", "document_id": "doc-1"}])
|
||||
res = module._build_reference_chunks([], include_metadata=False)
|
||||
assert res == [{"dataset_id": "kb-1", "document_id": "doc-1"}]
|
||||
|
||||
monkeypatch.setattr(module, "chunks_format", lambda _reference: [{"dataset_id": "kb-1"}, {"document_id": "doc-2"}])
|
||||
res = module._build_reference_chunks([], include_metadata=True)
|
||||
assert all("document_metadata" not in chunk for chunk in res)
|
||||
|
||||
monkeypatch.setattr(module, "chunks_format", lambda _reference: [{"dataset_id": "kb-1", "document_id": "doc-1"}])
|
||||
monkeypatch.setattr(module.DocMetadataService, "get_metadata_for_documents", lambda _doc_ids, _kb_id: {"doc-1": {"author": "alice"}})
|
||||
res = module._build_reference_chunks([], include_metadata=True, metadata_fields=[1, None])
|
||||
assert "document_metadata" not in res[0]
|
||||
|
||||
source_chunks = [
|
||||
{"dataset_id": "kb-1", "document_id": "doc-1"},
|
||||
{"dataset_id": "kb-2", "document_id": "doc-2"},
|
||||
{"dataset_id": "kb-1", "document_id": "doc-3"},
|
||||
{"dataset_id": "kb-1", "document_id": None},
|
||||
]
|
||||
monkeypatch.setattr(module, "chunks_format", lambda _reference: [dict(chunk) for chunk in source_chunks])
|
||||
|
||||
def _get_metadata(_doc_ids, kb_id):
|
||||
if kb_id == "kb-1":
|
||||
return {"doc-1": {"author": "alice", "year": 2024}}
|
||||
if kb_id == "kb-2":
|
||||
return {"doc-2": {"author": "bob", "tag": "rag"}}
|
||||
return {}
|
||||
|
||||
monkeypatch.setattr(module.DocMetadataService, "get_metadata_for_documents", _get_metadata)
|
||||
res = module._build_reference_chunks([], include_metadata=True, metadata_fields=["author", "missing", 3])
|
||||
assert res[0]["document_metadata"] == {"author": "alice"}
|
||||
assert res[1]["document_metadata"] == {"author": "bob"}
|
||||
assert "document_metadata" not in res[2]
|
||||
assert "document_metadata" not in res[3]
|
||||
|
||||
@ -17,6 +17,7 @@
|
||||
import pytest
|
||||
from ragflow_sdk import RAGFlow
|
||||
from ragflow_sdk.modules.agent import Agent
|
||||
from ragflow_sdk.modules.session import Session
|
||||
|
||||
|
||||
class _DummyResponse:
|
||||
@ -27,6 +28,16 @@ class _DummyResponse:
|
||||
return self._payload
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def auth():
|
||||
return "unit-auth"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def set_tenant_info():
|
||||
return None
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_list_agents_success_and_error(monkeypatch):
|
||||
client = RAGFlow("token", "http://localhost:9380")
|
||||
@ -122,3 +133,84 @@ def test_delete_agent_success_and_error(monkeypatch):
|
||||
with pytest.raises(Exception) as exception_info:
|
||||
client.delete_agent("agent-1")
|
||||
assert "delete boom" in str(exception_info.value), str(exception_info.value)
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_agent_and_dsl_default_initialization():
|
||||
client = RAGFlow("token", "http://localhost:9380")
|
||||
|
||||
agent = Agent(client, {"id": "agent-1", "title": "Agent One"})
|
||||
assert agent.id == "agent-1"
|
||||
assert agent.avatar is None
|
||||
assert agent.canvas_type is None
|
||||
assert agent.description is None
|
||||
assert agent.dsl is None
|
||||
|
||||
dsl = Agent.Dsl(client, {})
|
||||
assert dsl.answer == []
|
||||
assert "begin" in dsl.components
|
||||
assert dsl.components["begin"]["obj"]["component_name"] == "Begin"
|
||||
assert dsl.graph["nodes"][0]["id"] == "begin"
|
||||
assert dsl.history == []
|
||||
assert dsl.messages == []
|
||||
assert dsl.path == []
|
||||
assert dsl.reference == []
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_agent_session_methods_success_and_error_paths(monkeypatch):
|
||||
client = RAGFlow("token", "http://localhost:9380")
|
||||
agent = Agent(client, {"id": "agent-1"})
|
||||
calls = {"post": [], "get": [], "rm": []}
|
||||
|
||||
def _ok_post(path, json=None, stream=False, files=None):
|
||||
calls["post"].append((path, json, stream, files))
|
||||
return _DummyResponse({"code": 0, "data": {"id": "session-1", "agent_id": "agent-1", "name": "one"}})
|
||||
|
||||
def _ok_get(path, params=None):
|
||||
calls["get"].append((path, params))
|
||||
return _DummyResponse(
|
||||
{
|
||||
"code": 0,
|
||||
"data": [
|
||||
{"id": "session-1", "agent_id": "agent-1", "name": "one"},
|
||||
{"id": "session-2", "agent_id": "agent-1", "name": "two"},
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
def _ok_rm(path, payload):
|
||||
calls["rm"].append((path, payload))
|
||||
return _DummyResponse({"code": 0, "message": "ok"})
|
||||
|
||||
monkeypatch.setattr(agent, "post", _ok_post)
|
||||
monkeypatch.setattr(agent, "get", _ok_get)
|
||||
monkeypatch.setattr(agent, "rm", _ok_rm)
|
||||
|
||||
session = agent.create_session(name="session-name")
|
||||
assert isinstance(session, Session), str(session)
|
||||
assert session.id == "session-1", str(session)
|
||||
assert calls["post"][-1][0] == "/agents/agent-1/sessions"
|
||||
assert calls["post"][-1][1] == {"name": "session-name"}
|
||||
|
||||
sessions = agent.list_sessions(page=2, page_size=5, orderby="create_time", desc=False, id="session-1")
|
||||
assert len(sessions) == 2, str(sessions)
|
||||
assert all(isinstance(item, Session) for item in sessions), str(sessions)
|
||||
assert calls["get"][-1][0] == "/agents/agent-1/sessions"
|
||||
assert calls["get"][-1][1]["page"] == 2
|
||||
assert calls["get"][-1][1]["id"] == "session-1"
|
||||
|
||||
agent.delete_sessions(ids=["session-1", "session-2"])
|
||||
assert calls["rm"][-1] == ("/agents/agent-1/sessions", {"ids": ["session-1", "session-2"]})
|
||||
|
||||
monkeypatch.setattr(agent, "post", lambda *_args, **_kwargs: _DummyResponse({"code": 1, "message": "create failed"}))
|
||||
with pytest.raises(Exception, match="create failed"):
|
||||
agent.create_session(name="bad")
|
||||
|
||||
monkeypatch.setattr(agent, "get", lambda *_args, **_kwargs: _DummyResponse({"code": 2, "message": "list failed"}))
|
||||
with pytest.raises(Exception, match="list failed"):
|
||||
agent.list_sessions()
|
||||
|
||||
monkeypatch.setattr(agent, "rm", lambda *_args, **_kwargs: _DummyResponse({"code": 3, "message": "delete failed"}))
|
||||
with pytest.raises(Exception, match="delete failed"):
|
||||
agent.delete_sessions(ids=["session-1"])
|
||||
|
||||
@ -17,6 +17,28 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
import pytest
|
||||
from configs import SESSION_WITH_CHAT_NAME_LIMIT
|
||||
from ragflow_sdk import RAGFlow
|
||||
from ragflow_sdk.modules.session import Session
|
||||
|
||||
|
||||
class _DummyStreamResponse:
|
||||
def __init__(self, lines):
|
||||
self._lines = lines
|
||||
|
||||
def iter_lines(self, decode_unicode=True):
|
||||
del decode_unicode
|
||||
for line in self._lines:
|
||||
yield line
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def auth():
|
||||
return "unit-auth"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def set_tenant_info():
|
||||
return None
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("clear_session_with_chat_assistants")
|
||||
@ -74,3 +96,72 @@ class TestSessionWithChatAssistantCreate:
|
||||
with pytest.raises(Exception) as exception_info:
|
||||
chat_assistant.create_session(name="valid_name")
|
||||
assert "You do not own the assistant" in str(exception_info.value)
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_session_module_streaming_and_helper_paths_unit(monkeypatch):
|
||||
client = RAGFlow("token", "http://localhost:9380")
|
||||
chat_session = Session(client, {"id": "session-chat", "chat_id": "chat-1"})
|
||||
chat_done_session = Session(client, {"id": "session-chat-done", "chat_id": "chat-1"})
|
||||
agent_session = Session(client, {"id": "session-agent", "agent_id": "agent-1"})
|
||||
calls = []
|
||||
|
||||
chat_stream = _DummyStreamResponse(
|
||||
[
|
||||
"",
|
||||
"data: {bad json}",
|
||||
'data: {"event":"workflow_started","data":{"content":"skip"}}',
|
||||
'{"data":{"answer":"chat-answer","reference":{"chunks":[{"id":"chunk-1"}]}}}',
|
||||
'data: {"data": true}',
|
||||
"data: [DONE]",
|
||||
]
|
||||
)
|
||||
agent_stream = _DummyStreamResponse(
|
||||
[
|
||||
"data: {bad json}",
|
||||
'data: {"event":"message","data":{"content":"agent-answer"}}',
|
||||
'data: {"event":"message_end","data":{"content":"done"}}',
|
||||
]
|
||||
)
|
||||
|
||||
def _chat_post(path, json=None, stream=False, files=None):
|
||||
calls.append(("chat", path, json, stream, files))
|
||||
return chat_stream
|
||||
|
||||
def _agent_post(path, json=None, stream=False, files=None):
|
||||
calls.append(("agent", path, json, stream, files))
|
||||
return agent_stream
|
||||
|
||||
monkeypatch.setattr(chat_session, "post", _chat_post)
|
||||
monkeypatch.setattr(
|
||||
chat_done_session,
|
||||
"post",
|
||||
lambda *_args, **_kwargs: _DummyStreamResponse(
|
||||
['{"data":{"answer":"chat-done","reference":{"chunks":[]}}}', "data: [DONE]"]
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(agent_session, "post", _agent_post)
|
||||
|
||||
chat_messages = list(chat_session.ask("hello chat", stream=True, temperature=0.2))
|
||||
assert len(chat_messages) == 1
|
||||
assert chat_messages[0].content == "chat-answer"
|
||||
assert chat_messages[0].reference == [{"id": "chunk-1"}]
|
||||
|
||||
chat_done_messages = list(chat_done_session.ask("hello done", stream=True))
|
||||
assert len(chat_done_messages) == 1
|
||||
assert chat_done_messages[0].content == "chat-done"
|
||||
|
||||
agent_messages = list(agent_session.ask("hello agent", stream=True, top_p=0.8))
|
||||
assert len(agent_messages) == 1
|
||||
assert agent_messages[0].content == "agent-answer"
|
||||
|
||||
assert calls[0][1] == "/chats/chat-1/completions"
|
||||
assert calls[0][2]["question"] == "hello chat"
|
||||
assert calls[0][2]["session_id"] == "session-chat"
|
||||
assert calls[0][2]["temperature"] == 0.2
|
||||
assert calls[0][3] is True
|
||||
assert calls[1][1] == "/agents/agent-1/completions"
|
||||
assert calls[1][2]["question"] == "hello agent"
|
||||
assert calls[1][2]["session_id"] == "session-agent"
|
||||
assert calls[1][2]["top_p"] == 0.8
|
||||
assert calls[1][3] is True
|
||||
|
||||
@ -15,6 +15,18 @@
|
||||
#
|
||||
import pytest
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from ragflow_sdk import RAGFlow
|
||||
from ragflow_sdk.modules.session import Message, Session
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def auth():
|
||||
return "unit-auth"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def set_tenant_info():
|
||||
return None
|
||||
|
||||
|
||||
class TestSessionsWithChatAssistantList:
|
||||
@ -215,3 +227,54 @@ class TestSessionsWithChatAssistantList:
|
||||
with pytest.raises(Exception) as exception_info:
|
||||
chat_assistant.list_sessions()
|
||||
assert "You don't own the assistant" in str(exception_info.value)
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_session_module_error_paths_unit(monkeypatch):
|
||||
client = RAGFlow("token", "http://localhost:9380")
|
||||
|
||||
unknown_session = Session(client, {"id": "session-unknown", "chat_id": "chat-1"})
|
||||
unknown_session._Session__session_type = "unknown" # noqa: SLF001
|
||||
with pytest.raises(Exception) as exception_info:
|
||||
list(unknown_session.ask("hello", stream=False))
|
||||
assert "Unknown session type" in str(exception_info.value)
|
||||
|
||||
bad_json_session = Session(client, {"id": "session-bad-json", "chat_id": "chat-1"})
|
||||
|
||||
class _BadJsonResponse:
|
||||
def json(self):
|
||||
raise ValueError("json decode failed")
|
||||
|
||||
monkeypatch.setattr(bad_json_session, "post", lambda *_args, **_kwargs: _BadJsonResponse())
|
||||
with pytest.raises(Exception) as exception_info:
|
||||
list(bad_json_session.ask("hello", stream=False))
|
||||
assert "Invalid response" in str(exception_info.value)
|
||||
|
||||
ok_json_session = Session(client, {"id": "session-ok-json", "chat_id": "chat-1"})
|
||||
|
||||
class _OkJsonResponse:
|
||||
def json(self):
|
||||
return {"data": {"answer": "ok-answer", "reference": {"chunks": [{"id": "chunk-ok"}]}}}
|
||||
|
||||
monkeypatch.setattr(ok_json_session, "post", lambda *_args, **_kwargs: _OkJsonResponse())
|
||||
ok_messages = list(ok_json_session.ask("hello", stream=False))
|
||||
assert len(ok_messages) == 1
|
||||
assert ok_messages[0].content == "ok-answer"
|
||||
assert ok_messages[0].reference == [{"id": "chunk-ok"}]
|
||||
|
||||
transport_session = Session(client, {"id": "session-transport", "chat_id": "chat-1"})
|
||||
monkeypatch.setattr(
|
||||
transport_session,
|
||||
"post",
|
||||
lambda *_args, **_kwargs: (_ for _ in ()).throw(RuntimeError("transport boom")),
|
||||
)
|
||||
with pytest.raises(RuntimeError) as exception_info:
|
||||
list(transport_session.ask("hello", stream=False))
|
||||
assert "transport boom" in str(exception_info.value)
|
||||
|
||||
message = Message(client, {})
|
||||
assert message.content == "Hi! I am your assistant, can I help you?"
|
||||
assert message.reference is None
|
||||
assert message.role == "assistant"
|
||||
assert message.prompt is None
|
||||
assert message.id is None
|
||||
|
||||
@ -0,0 +1,197 @@
|
||||
#
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import importlib.util
|
||||
import sys
|
||||
import urllib.parse
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class _FakeResponse:
|
||||
def __init__(self, payload=None, err=None):
|
||||
self._payload = payload or {}
|
||||
self._err = err
|
||||
|
||||
def raise_for_status(self):
|
||||
if self._err:
|
||||
raise self._err
|
||||
|
||||
def json(self):
|
||||
return self._payload
|
||||
|
||||
|
||||
def _base_config(scope="openid profile"):
|
||||
return {
|
||||
"client_id": "client-1",
|
||||
"client_secret": "secret-1",
|
||||
"authorization_url": "https://issuer.example/authorize",
|
||||
"token_url": "https://issuer.example/token",
|
||||
"userinfo_url": "https://issuer.example/userinfo",
|
||||
"redirect_uri": "https://app.example/callback",
|
||||
"scope": scope,
|
||||
}
|
||||
|
||||
|
||||
def _load_oauth_module(monkeypatch):
    """Import ``api/apps/auth/oauth.py`` in isolation and return the module.

    Installs stub ``common`` / ``api`` package entries into ``sys.modules``
    so the OAuth module can be executed without the full application being
    importable.  All entries go through ``monkeypatch.setitem`` so they are
    removed automatically when the test finishes.
    """
    # Repo root is four directory levels above this test file.
    repo_root = Path(__file__).resolve().parents[4]

    common_pkg = ModuleType("common")
    common_pkg.__path__ = [str(repo_root / "common")]
    monkeypatch.setitem(sys.modules, "common", common_pkg)

    http_client_mod = ModuleType("common.http_client")

    async def _default_async_request(*_args, **_kwargs):
        return _FakeResponse({})

    def _default_sync_request(*_args, **_kwargs):
        return _FakeResponse({})

    # Default HTTP stubs succeed with an empty JSON body; individual tests
    # monkeypatch these attributes on the loaded module as needed.
    http_client_mod.async_request = _default_async_request
    http_client_mod.sync_request = _default_sync_request
    monkeypatch.setitem(sys.modules, "common.http_client", http_client_mod)

    api_pkg = ModuleType("api")
    api_pkg.__path__ = [str(repo_root / "api")]
    apps_pkg = ModuleType("api.apps")
    apps_pkg.__path__ = [str(repo_root / "api" / "apps")]
    auth_pkg = ModuleType("api.apps.auth")
    auth_pkg.__path__ = [str(repo_root / "api" / "apps" / "auth")]

    monkeypatch.setitem(sys.modules, "api", api_pkg)
    monkeypatch.setitem(sys.modules, "api.apps", apps_pkg)
    monkeypatch.setitem(sys.modules, "api.apps.auth", auth_pkg)

    # Drop any previously imported copy so every test gets a fresh module
    # bound to the stubbed dependencies above.
    sys.modules.pop("api.apps.auth.oauth", None)
    oauth_path = repo_root / "api" / "apps" / "auth" / "oauth.py"
    oauth_spec = importlib.util.spec_from_file_location("api.apps.auth.oauth", oauth_path)
    oauth_module = importlib.util.module_from_spec(oauth_spec)
    # Register before exec so intra-module imports resolve during execution.
    monkeypatch.setitem(sys.modules, "api.apps.auth.oauth", oauth_module)
    oauth_spec.loader.exec_module(oauth_module)
    return oauth_module
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
def set_tenant_info():
    # Session-wide no-op: satisfies the shared conftest's expectation that a
    # tenant-info fixture exists, without doing any real setup.
    return None
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_oauth_client_sync_matrix_unit(monkeypatch):
    """Exercise the synchronous OAuthClient paths: config wiring, UserInfo,
    authorization-URL building, token/userinfo exchange, and error wrapping."""
    oauth_module = _load_oauth_module(monkeypatch)
    client = oauth_module.OAuthClient(_base_config())

    # Constructor should map every config key onto an attribute.
    assert client.client_id == "client-1"
    assert client.client_secret == "secret-1"
    assert client.authorization_url.endswith("/authorize")
    assert client.token_url.endswith("/token")
    assert client.userinfo_url.endswith("/userinfo")
    assert client.redirect_uri.endswith("/callback")
    assert client.scope == "openid profile"
    assert client.http_request_timeout == 7

    info = oauth_module.UserInfo("u@example.com", "user1", "User One", "avatar-url")
    assert info.to_dict() == {
        "email": "u@example.com",
        "username": "user1",
        "nickname": "User One",
        "avatar_url": "avatar-url",
    }

    # state must survive URL-encoding round-trip intact.
    auth_url = client.get_authorization_url(state="s p/a?ce")
    parsed = urllib.parse.urlparse(auth_url)
    query = urllib.parse.parse_qs(parsed.query)
    assert parsed.scheme == "https"
    assert query["client_id"] == ["client-1"]
    assert query["redirect_uri"] == ["https://app.example/callback"]
    assert query["response_type"] == ["code"]
    assert query["scope"] == ["openid profile"]
    assert query["state"] == ["s p/a?ce"]

    # scope=None must omit the scope query parameter entirely.
    no_scope_client = oauth_module.OAuthClient(_base_config(scope=None))
    no_scope_query = urllib.parse.parse_qs(urllib.parse.urlparse(no_scope_client.get_authorization_url()).query)
    assert "scope" not in no_scope_query

    call_log = []

    def _sync_ok(method, url, data=None, headers=None, timeout=None):
        # Record every outgoing request so headers/method can be asserted below.
        call_log.append((method, url, data, headers, timeout))
        if url.endswith("/token"):
            return _FakeResponse({"access_token": "token-1"})
        return _FakeResponse({"email": "user@example.com", "picture": "id-picture"})

    monkeypatch.setattr(oauth_module, "sync_request", _sync_ok)
    token = client.exchange_code_for_token("code-1")
    assert token["access_token"] == "token-1"
    user_info = client.fetch_user_info("access-1")
    assert isinstance(user_info, oauth_module.UserInfo)
    # username/nickname fall back to the email local part; avatar from "picture".
    assert user_info.to_dict() == {
        "email": "user@example.com",
        "username": "user",
        "nickname": "user",
        "avatar_url": "id-picture",
    }
    assert call_log[0][0] == "POST"
    assert call_log[0][3]["Accept"] == "application/json"
    assert call_log[1][0] == "GET"
    assert call_log[1][3]["Authorization"] == "Bearer access-1"

    # Direct avatar_url key should be preferred over provider-specific keys.
    normalized = client.normalize_user_info(
        {"email": "fallback@example.com", "username": "fallback-user", "nickname": "fallback-nick", "avatar_url": "direct-avatar"}
    )
    assert normalized.to_dict()["avatar_url"] == "direct-avatar"

    # HTTP errors must be re-raised as ValueError with a descriptive prefix.
    monkeypatch.setattr(oauth_module, "sync_request", lambda *_args, **_kwargs: _FakeResponse(err=RuntimeError("status boom")))
    with pytest.raises(ValueError, match="Failed to exchange authorization code for token: status boom"):
        client.exchange_code_for_token("code-2")
    with pytest.raises(ValueError, match="Failed to fetch user info: status boom"):
        client.fetch_user_info("access-2")
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_oauth_client_async_matrix_unit(monkeypatch):
    """Mirror of the sync matrix for the async_* OAuthClient methods:
    success path, then error wrapping into ValueError."""
    oauth_module = _load_oauth_module(monkeypatch)
    client = oauth_module.OAuthClient(_base_config())

    async def _async_ok(method, url, data=None, headers=None, **kwargs):
        # Consume the arguments to mimic the real call signature.
        _ = (method, data, headers, kwargs.get("timeout"))
        if url.endswith("/token"):
            return _FakeResponse({"access_token": "token-async"})
        return _FakeResponse({"email": "async@example.com", "username": "async-user", "nickname": "Async User", "avatar_url": "async-avatar"})

    monkeypatch.setattr(oauth_module, "async_request", _async_ok)
    token = asyncio.run(client.async_exchange_code_for_token("code-a"))
    assert token["access_token"] == "token-async"
    info = asyncio.run(client.async_fetch_user_info("async-token"))
    assert info.to_dict() == {
        "email": "async@example.com",
        "username": "async-user",
        "nickname": "Async User",
        "avatar_url": "async-avatar",
    }

    async def _async_fail(*_args, **_kwargs):
        # raise_for_status() on this response raises RuntimeError("async boom").
        return _FakeResponse(err=RuntimeError("async boom"))

    monkeypatch.setattr(oauth_module, "async_request", _async_fail)
    with pytest.raises(ValueError, match="Failed to exchange authorization code for token: async boom"):
        asyncio.run(client.async_exchange_code_for_token("code-b"))
    with pytest.raises(ValueError, match="Failed to fetch user info: async boom"):
        asyncio.run(client.async_fetch_user_info("async-token-2"))
|
||||
File diff suppressed because it is too large
Load Diff
@ -17,6 +17,7 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import importlib.util
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from types import ModuleType, SimpleNamespace
|
||||
@ -71,6 +72,7 @@ class _DummyRetCode:
|
||||
SUCCESS = 0
|
||||
DATA_ERROR = 102
|
||||
EXCEPTION_ERROR = 100
|
||||
OPERATING_ERROR = 103
|
||||
|
||||
|
||||
class _DummyParserType:
|
||||
@ -204,6 +206,7 @@ def _load_chunk_module(monkeypatch):
|
||||
class _DummyLLMType:
|
||||
EMBEDDING = SimpleNamespace(value="embedding")
|
||||
CHAT = SimpleNamespace(value="chat")
|
||||
RERANK = SimpleNamespace(value="rerank")
|
||||
|
||||
constants_mod.RetCode = _DummyRetCode
|
||||
constants_mod.LLMType = _DummyLLMType
|
||||
@ -365,6 +368,11 @@ def _set_request_json(monkeypatch, module, payload):
|
||||
monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue(payload))
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def set_tenant_info():
|
||||
return None
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_list_chunk_exception_branches_unit(monkeypatch):
|
||||
module = _load_chunk_module(monkeypatch)
|
||||
@ -663,7 +671,180 @@ def test_create_chunk_guards_pagerank_and_success_unit(monkeypatch):
|
||||
)
|
||||
res = _run(module.create())
|
||||
assert res["code"] == 0, res
|
||||
assert res["data"]["chunk_id"], res
|
||||
assert module.settings.docStoreConn.inserted, "insert should be called"
|
||||
inserted = module.settings.docStoreConn.inserted[-1]
|
||||
assert "pagerank_flt" in inserted
|
||||
assert module.DocumentService.increment_calls, "increment_chunk_num should be called"
|
||||
|
||||
async def _raise_thread_pool(_func):
|
||||
raise RuntimeError("create tp boom")
|
||||
|
||||
monkeypatch.setattr(module, "thread_pool_exec", _raise_thread_pool)
|
||||
_set_request_json(monkeypatch, module, {"doc_id": "doc-1", "content_with_weight": "chunk"})
|
||||
res = _run(module.create())
|
||||
assert res["code"] == module.RetCode.EXCEPTION_ERROR, res
|
||||
assert "create tp boom" in res["message"], res
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_retrieval_test_branch_matrix_unit(monkeypatch):
    """Drive retrieval_test() through its branch matrix: permission failure,
    missing kb_id, missing knowledgebase, the full success path (cross-language
    + keyword + KG), the empty-KG fallback, and both exception branches."""
    module = _load_chunk_module(monkeypatch)
    module.request = SimpleNamespace(headers={"X-Request-ID": "req-r"}, args={})

    # Call recorders for the collaborators stubbed below.
    applied_filters = []
    llm_calls = []
    cross_calls = []
    keyword_calls = []

    async def _apply_filter(meta_data_filter, metas, question, chat_mdl, local_doc_ids):
        applied_filters.append(
            {
                "meta_data_filter": meta_data_filter,
                "metas": metas,
                "question": question,
                "chat_mdl": chat_mdl,
                "local_doc_ids": list(local_doc_ids),
            }
        )
        return ["doc-filtered"]

    async def _cross_languages(_tenant_id, _dialog, question, langs):
        cross_calls.append((question, tuple(langs)))
        return f"{question}-xl"

    async def _keyword_extraction(_chat_mdl, question):
        keyword_calls.append(question)
        return "-kw"

    class _Retriever:
        # mode selects the behavior: "ok" returns chunks, "not_found" raises an
        # exception whose message triggers the DATA_ERROR branch, "explode"
        # raises an unexpected error for the EXCEPTION_ERROR branch.
        def __init__(self, mode="ok"):
            self.mode = mode
            self.retrieval_questions = []

        async def retrieval(self, question, *_args, **_kwargs):
            if self.mode == "not_found":
                raise Exception("boom not_found boom")
            if self.mode == "explode":
                raise RuntimeError("retrieval boom")
            self.retrieval_questions.append(question)
            return {"chunks": [{"id": "c1", "vector": [0.1], "content_with_weight": "chunk-content"}]}

        def retrieval_by_children(self, chunks, _tenant_ids):
            return list(chunks)

    class _KgRetriever:
        async def retrieval(self, *_args, **_kwargs):
            return {"id": "kg-1", "content_with_weight": "kg-content"}

    class _NoContentKgRetriever:
        # Empty content means the KG chunk must NOT be prepended to results.
        async def retrieval(self, *_args, **_kwargs):
            return {"id": "kg-2", "content_with_weight": ""}

    monkeypatch.setattr(module, "LLMBundle", lambda *args, **kwargs: llm_calls.append((args, kwargs)) or SimpleNamespace())
    monkeypatch.setattr(module.DocMetadataService, "get_flatted_meta_by_kbs", lambda _kb_ids: [{"meta": "v"}], raising=False)
    monkeypatch.setattr(module, "apply_meta_data_filter", _apply_filter)
    monkeypatch.setattr(module.SearchService, "get_detail", lambda _sid: {"search_config": {"meta_data_filter": {"method": "auto"}, "chat_id": "chat-1"}}, raising=False)
    monkeypatch.setattr(module, "cross_languages", _cross_languages)
    monkeypatch.setattr(module, "keyword_extraction", _keyword_extraction)
    monkeypatch.setattr(module, "label_question", lambda *_args, **_kwargs: ["lbl"])
    monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [_DummyTenant("tenant-1")])

    # Branch 1: tenant does not own the dataset -> OPERATING_ERROR, but the
    # search_id "auto" metadata filter must still have run beforehand.
    monkeypatch.setattr(module.KnowledgebaseService, "query", lambda **_kwargs: False, raising=False)
    _set_request_json(monkeypatch, module, {"kb_id": "kb-1", "question": "q", "search_id": "search-1"})
    res = _run(module.retrieval_test())
    assert res["code"] == module.RetCode.OPERATING_ERROR, res
    assert "Only owner of dataset authorized for this operation." in res["message"], res
    assert applied_filters and applied_filters[-1]["meta_data_filter"]["method"] == "auto"
    assert llm_calls, "search_id metadata auto branch should instantiate chat model"

    # Branch 2: empty kb_id list -> DATA_ERROR.
    _set_request_json(monkeypatch, module, {"kb_id": [], "question": "q"})
    res = _run(module.retrieval_test())
    assert res["code"] == module.RetCode.DATA_ERROR, res
    assert "Please specify dataset firstly." in res["message"], res

    # Branch 3: knowledgebase lookup fails -> DATA_ERROR.
    monkeypatch.setattr(module.KnowledgebaseService, "query", lambda **_kwargs: True, raising=False)
    monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (False, None), raising=False)
    _set_request_json(
        monkeypatch,
        module,
        {"kb_id": ["kb-1"], "question": "q", "meta_data_filter": {"method": "semi_auto"}},
    )
    res = _run(module.retrieval_test())
    assert res["code"] == module.RetCode.DATA_ERROR, res
    assert "Knowledgebase not found!" in res["message"], res

    # Branch 4: full success path with cross-language rewrite, keyword
    # extraction suffix, and a KG chunk prepended; vectors stripped from output.
    retriever = _Retriever(mode="ok")
    monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, SimpleNamespace(tenant_id="tenant-kb", embd_id="embd-1")), raising=False)
    monkeypatch.setattr(module.settings, "retriever", retriever)
    monkeypatch.setattr(module.settings, "kg_retriever", _KgRetriever(), raising=False)
    _set_request_json(
        monkeypatch,
        module,
        {
            "kb_id": ["kb-1"],
            "question": "q",
            "cross_languages": ["fr"],
            "rerank_id": "rerank-1",
            "keyword": True,
            "use_kg": True,
        },
    )
    res = _run(module.retrieval_test())
    assert res["code"] == 0, res
    assert cross_calls[-1] == ("q", ("fr",))
    assert keyword_calls[-1] == "q-xl"
    assert retriever.retrieval_questions[-1] == "q-xl-kw"
    assert res["data"]["chunks"][0]["id"] == "kg-1", res
    assert all("vector" not in chunk for chunk in res["data"]["chunks"])

    # Branch 5: KG chunk with empty content is dropped; retriever chunk leads.
    monkeypatch.setattr(module.settings, "kg_retriever", _NoContentKgRetriever(), raising=False)
    _set_request_json(monkeypatch, module, {"kb_id": ["kb-1"], "question": "q", "use_kg": True})
    res = _run(module.retrieval_test())
    assert res["code"] == 0, res
    assert res["data"]["chunks"][0]["id"] == "c1", res

    # Branch 6: "not_found" in the exception message -> friendly DATA_ERROR.
    monkeypatch.setattr(module.settings, "retriever", _Retriever(mode="not_found"))
    _set_request_json(monkeypatch, module, {"kb_id": ["kb-1"], "question": "q"})
    res = _run(module.retrieval_test())
    assert res["code"] == module.RetCode.DATA_ERROR, res
    assert "No chunk found! Check the chunk status please!" in res["message"], res

    # Branch 7: any other exception -> EXCEPTION_ERROR with original message.
    monkeypatch.setattr(module.settings, "retriever", _Retriever(mode="explode"))
    _set_request_json(monkeypatch, module, {"kb_id": ["kb-1"], "question": "q"})
    res = _run(module.retrieval_test())
    assert res["code"] == module.RetCode.EXCEPTION_ERROR, res
    assert "retrieval boom" in res["message"], res
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_knowledge_graph_repeat_deal_matrix_unit(monkeypatch):
    """knowledge_graph(): unparseable graph JSON degrades to {}, and duplicate
    mind-map node ids are de-duplicated with "(n)" suffixes."""
    module = _load_chunk_module(monkeypatch)
    module.request = SimpleNamespace(args={"doc_id": "doc-1"}, headers={})

    # Three nodes share the id "dup" (two siblings plus a nested child) to
    # exercise the repeat-handling rename logic.
    payload = {
        "id": "root",
        "children": [
            {"id": "dup"},
            {"id": "dup", "children": [{"id": "dup"}]},
        ],
    }

    class _SRes:
        # Search-result stub: one malformed graph doc, one valid mind map.
        ids = ["bad-json", "mind-map"]
        field = {
            "bad-json": {"knowledge_graph_kwd": "graph", "content_with_weight": "{bad json"},
            "mind-map": {"knowledge_graph_kwd": "mind_map", "content_with_weight": json.dumps(payload)},
        }

    async def _search(*_args, **_kwargs):
        return _SRes()

    monkeypatch.setattr(module.settings.retriever, "search", _search)
    res = _run(module.knowledge_graph())
    assert res["code"] == 0, res
    # Invalid JSON for the graph must yield an empty dict, not an error.
    assert res["data"]["graph"] == {}, res
    mind_map = res["data"]["mind_map"]
    # First occurrence keeps its id; later duplicates get numeric suffixes.
    assert mind_map["children"][0]["id"] == "dup", res
    assert mind_map["children"][1]["id"] == "dup(1)", res
    assert mind_map["children"][1]["children"][0]["id"] == "dup(2)", res
|
||||
|
||||
@ -0,0 +1,711 @@
|
||||
#
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import importlib.util
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from types import ModuleType, SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class _DummyManager:
|
||||
def route(self, *_args, **_kwargs):
|
||||
def decorator(func):
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class _AwaitableValue:
|
||||
def __init__(self, value):
|
||||
self._value = value
|
||||
|
||||
def __await__(self):
|
||||
async def _co():
|
||||
return self._value
|
||||
|
||||
return _co().__await__()
|
||||
|
||||
|
||||
class _Args(dict):
|
||||
def get(self, key, default=None, type=None):
|
||||
value = super().get(key, default)
|
||||
if type is None:
|
||||
return value
|
||||
try:
|
||||
return type(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
def to_dict(self, flat=True):
|
||||
return dict(self)
|
||||
|
||||
|
||||
class _FakeResponse:
    """Response object produced by the stubbed quart ``make_response``:
    carries the rendered body, the HTTP status, and a mutable headers dict."""

    def __init__(self, body, status_code):
        self.body = body
        self.status_code = status_code
        # Starts empty; route code under test fills in e.g. Content-Type.
        self.headers = {}
|
||||
|
||||
|
||||
class _FakeConnectorRecord:
|
||||
def __init__(self, payload):
|
||||
self._payload = payload
|
||||
|
||||
def to_dict(self):
|
||||
return dict(self._payload)
|
||||
|
||||
|
||||
class _FakeCredentials:
    """Google-credentials stub; ``to_json`` echoes the canned JSON string."""

    def __init__(self, raw='{"refresh_token":"rt","access_token":"at"}'):
        self._raw = raw

    def to_json(self):
        return self._raw
|
||||
|
||||
|
||||
class _FakeFlow:
    """Stand-in for ``google_auth_oauthlib`` Flow that records how it was driven.

    ``auth_kwargs`` captures the keyword arguments passed to
    ``authorization_url`` and ``token_code`` the code handed to ``fetch_token``,
    so tests can assert on both.
    """

    def __init__(self, client_config, scopes):
        self.client_config = client_config
        self.scopes = scopes
        self.redirect_uri = None
        self.credentials = _FakeCredentials()
        self.auth_kwargs = None
        self.token_code = None

    def authorization_url(self, **kwargs):
        # Snapshot the kwargs (prompt, state, access_type, ...) for assertions.
        self.auth_kwargs = dict(kwargs)
        state = kwargs["state"]
        return f"https://oauth.example/{state}", state

    def fetch_token(self, code):
        self.token_code = code
|
||||
|
||||
|
||||
class _FakeBoxToken:
    """Access/refresh token pair returned by the Box OAuth stub."""

    def __init__(self, access_token, refresh_token):
        self.access_token = access_token
        self.refresh_token = refresh_token
|
||||
|
||||
|
||||
class _FakeBoxOAuth:
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.exchange_code = None
|
||||
|
||||
def get_authorize_url(self, options):
|
||||
return f"https://box.example/auth?state={options.state}&redirect={options.redirect_uri}"
|
||||
|
||||
def get_tokens_authorization_code_grant(self, code):
|
||||
self.exchange_code = code
|
||||
|
||||
def retrieve_token(self):
|
||||
return _FakeBoxToken("box-access", "box-refresh")
|
||||
|
||||
|
||||
class _FakeRedis:
|
||||
def __init__(self):
|
||||
self.store = {}
|
||||
self.set_calls = []
|
||||
self.deleted = []
|
||||
|
||||
def get(self, key):
|
||||
return self.store.get(key)
|
||||
|
||||
def set_obj(self, key, obj, ttl):
|
||||
self.set_calls.append((key, obj, ttl))
|
||||
self.store[key] = json.dumps(obj)
|
||||
|
||||
def delete(self, key):
|
||||
self.deleted.append(key)
|
||||
self.store.pop(key, None)
|
||||
|
||||
|
||||
def _run(coro):
    """Drive *coro* to completion on a fresh event loop and return its result."""
    return asyncio.run(coro)
|
||||
|
||||
|
||||
def _set_request(module, *, args=None, json_body=None):
    """Install a fake quart ``request`` on *module* with the given query args
    and awaitable JSON body."""
    body = {} if json_body is None else json_body
    module.request = SimpleNamespace(
        args=_Args(args or {}),
        json=_AwaitableValue(body),
    )
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def auth():
    # Dummy auth token; the stubbed routes under test never validate it.
    return "unit-auth"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
def set_tenant_info():
    # Session-wide no-op: satisfies the shared conftest's expectation that a
    # tenant-info fixture exists, without doing any real setup.
    return None
|
||||
|
||||
|
||||
def _load_connector_app(monkeypatch):
    """Load ``api/apps/connector_app.py`` with every dependency stubbed out.

    Builds fake ``api``/``common``/``rag``/``quart``/``google_auth_oauthlib``/
    ``box_sdk_gen`` modules in ``sys.modules`` (via ``monkeypatch.setitem`` so
    they are cleaned up per test), then executes the connector app module under
    a throwaway name and returns it.
    """
    # Repo root is four directory levels above this test file.
    repo_root = Path(__file__).resolve().parents[4]

    api_pkg = ModuleType("api")
    api_pkg.__path__ = [str(repo_root / "api")]
    monkeypatch.setitem(sys.modules, "api", api_pkg)

    apps_mod = ModuleType("api.apps")
    apps_mod.__path__ = [str(repo_root / "api" / "apps")]
    # Fixed current user so route code sees tenant "tenant-1".
    apps_mod.current_user = SimpleNamespace(id="tenant-1")
    apps_mod.login_required = lambda fn: fn
    monkeypatch.setitem(sys.modules, "api.apps", apps_mod)

    db_mod = ModuleType("api.db")
    db_mod.InputType = SimpleNamespace(POLL="POLL")
    monkeypatch.setitem(sys.modules, "api.db", db_mod)

    services_pkg = ModuleType("api.db.services")
    services_pkg.__path__ = []
    monkeypatch.setitem(sys.modules, "api.db.services", services_pkg)

    connector_service_mod = ModuleType("api.db.services.connector_service")

    # Default no-op service stubs; individual tests monkeypatch the methods
    # they care about directly on the loaded module's classes.
    class _StubConnectorService:
        @staticmethod
        def update_by_id(*_args, **_kwargs):
            return True

        @staticmethod
        def save(**_kwargs):
            return True

        @staticmethod
        def get_by_id(_connector_id):
            return True, _FakeConnectorRecord({"id": _connector_id})

        @staticmethod
        def list(_tenant_id):
            return []

        @staticmethod
        def resume(*_args, **_kwargs):
            return True

        @staticmethod
        def rebuild(*_args, **_kwargs):
            return None

        @staticmethod
        def delete_by_id(*_args, **_kwargs):
            return True

    class _StubSyncLogsService:
        @staticmethod
        def list_sync_tasks(*_args, **_kwargs):
            return [], 0

    connector_service_mod.ConnectorService = _StubConnectorService
    connector_service_mod.SyncLogsService = _StubSyncLogsService
    monkeypatch.setitem(sys.modules, "api.db.services.connector_service", connector_service_mod)

    api_utils_mod = ModuleType("api.utils.api_utils")

    async def _get_request_json():
        return {}

    # Result helpers mirror the real envelope shape {code, message, data}.
    api_utils_mod.get_request_json = _get_request_json
    api_utils_mod.get_json_result = lambda data=None, message="", code=0: {
        "code": code,
        "message": message,
        "data": data,
    }
    api_utils_mod.get_data_error_result = lambda message="", code=400, data=None: {
        "code": code,
        "message": message,
        "data": data,
    }
    api_utils_mod.validate_request = lambda *_args, **_kwargs: (lambda fn: fn)
    monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod)

    constants_mod = ModuleType("common.constants")
    constants_mod.RetCode = SimpleNamespace(
        ARGUMENT_ERROR=101,
        SERVER_ERROR=500,
        RUNNING=102,
        PERMISSION_ERROR=403,
    )
    constants_mod.TaskStatus = SimpleNamespace(SCHEDULE="schedule", CANCEL="cancel")
    monkeypatch.setitem(sys.modules, "common.constants", constants_mod)

    config_mod = ModuleType("common.data_source.config")
    config_mod.GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI = "https://example.com/drive"
    config_mod.GMAIL_WEB_OAUTH_REDIRECT_URI = "https://example.com/gmail"
    config_mod.BOX_WEB_OAUTH_REDIRECT_URI = "https://example.com/box"
    config_mod.DocumentSource = SimpleNamespace(GMAIL="gmail", GOOGLE_DRIVE="google-drive")
    monkeypatch.setitem(sys.modules, "common.data_source.config", config_mod)

    google_constants_mod = ModuleType("common.data_source.google_util.constant")
    # Minimal popup template exposing the same format placeholders as the real one.
    google_constants_mod.WEB_OAUTH_POPUP_TEMPLATE = (
        "<html><head><title>{title}</title></head>"
        "<body><h1>{heading}</h1><p>{message}</p><script>{payload_json}</script><script>{auto_close}</script></body></html>"
    )
    google_constants_mod.GOOGLE_SCOPES = {
        config_mod.DocumentSource.GMAIL: ["scope-gmail"],
        config_mod.DocumentSource.GOOGLE_DRIVE: ["scope-drive"],
    }
    monkeypatch.setitem(sys.modules, "common.data_source.google_util.constant", google_constants_mod)

    misc_mod = ModuleType("common.misc_utils")
    misc_mod.get_uuid = lambda: "uuid-from-helper"
    monkeypatch.setitem(sys.modules, "common.misc_utils", misc_mod)

    rag_pkg = ModuleType("rag")
    rag_pkg.__path__ = [str(repo_root / "rag")]
    monkeypatch.setitem(sys.modules, "rag", rag_pkg)

    rag_utils_pkg = ModuleType("rag.utils")
    rag_utils_pkg.__path__ = [str(repo_root / "rag" / "utils")]
    monkeypatch.setitem(sys.modules, "rag.utils", rag_utils_pkg)

    redis_mod = ModuleType("rag.utils.redis_conn")
    redis_mod.REDIS_CONN = _FakeRedis()
    monkeypatch.setitem(sys.modules, "rag.utils.redis_conn", redis_mod)

    quart_mod = ModuleType("quart")
    quart_mod.request = SimpleNamespace(args=_Args(), json=_AwaitableValue({}))

    async def _make_response(body, status_code):
        return _FakeResponse(body, status_code)

    quart_mod.make_response = _make_response
    monkeypatch.setitem(sys.modules, "quart", quart_mod)

    google_pkg = ModuleType("google_auth_oauthlib")
    google_pkg.__path__ = []
    monkeypatch.setitem(sys.modules, "google_auth_oauthlib", google_pkg)

    google_flow_mod = ModuleType("google_auth_oauthlib.flow")

    class _StubFlow:
        @classmethod
        def from_client_config(cls, client_config, scopes):
            return _FakeFlow(client_config, scopes)

    google_flow_mod.Flow = _StubFlow
    monkeypatch.setitem(sys.modules, "google_auth_oauthlib.flow", google_flow_mod)

    box_mod = ModuleType("box_sdk_gen")

    class _OAuthConfig:
        def __init__(self, client_id, client_secret):
            self.client_id = client_id
            self.client_secret = client_secret

    class _GetAuthorizeUrlOptions:
        def __init__(self, redirect_uri, state):
            self.redirect_uri = redirect_uri
            self.state = state

    box_mod.BoxOAuth = _FakeBoxOAuth
    box_mod.OAuthConfig = _OAuthConfig
    box_mod.GetAuthorizeUrlOptions = _GetAuthorizeUrlOptions
    monkeypatch.setitem(sys.modules, "box_sdk_gen", box_mod)

    # Execute the real connector_app source against the stubs above; inject the
    # dummy route manager before exec so @manager.route decorators are no-ops.
    module_path = repo_root / "api" / "apps" / "connector_app.py"
    spec = importlib.util.spec_from_file_location("test_connector_routes_unit", module_path)
    module = importlib.util.module_from_spec(spec)
    module.manager = _DummyManager()
    spec.loader.exec_module(module)
    return module
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_connector_basic_routes_and_task_controls(monkeypatch):
    """Cover the plain CRUD/control routes: set (update + create), list, get
    (found and missing), logs paging, resume/cancel, rebuild, and delete."""
    module = _load_connector_app(monkeypatch)

    async def _no_sleep(_secs):
        return None

    # Neutralize any awaited sleeps so the test stays fast.
    monkeypatch.setattr(module.asyncio, "sleep", _no_sleep)

    # Call recorders backing the stubbed ConnectorService methods.
    records = {"conn-1": _FakeConnectorRecord({"id": "conn-1", "source": "drive"})}
    update_calls = []
    save_calls = []
    resume_calls = []
    delete_calls = []

    monkeypatch.setattr(module.ConnectorService, "update_by_id", lambda cid, payload: update_calls.append((cid, payload)))

    def _save(**payload):
        save_calls.append(payload)
        records[payload["id"]] = _FakeConnectorRecord(payload)

    monkeypatch.setattr(module.ConnectorService, "save", _save)
    monkeypatch.setattr(module.ConnectorService, "get_by_id", lambda cid: (True, records[cid]))
    monkeypatch.setattr(module.ConnectorService, "list", lambda tenant_id: [{"id": "listed", "tenant": tenant_id}])
    monkeypatch.setattr(module.SyncLogsService, "list_sync_tasks", lambda cid, page, page_size: ([{"id": "log-1"}], 9))
    monkeypatch.setattr(module.ConnectorService, "resume", lambda cid, status: resume_calls.append((cid, status)))
    monkeypatch.setattr(module.ConnectorService, "delete_by_id", lambda cid: delete_calls.append(cid))
    monkeypatch.setattr(module, "get_uuid", lambda: "generated-id")

    # set_connector with an existing id -> update path; id is stripped
    # from the payload passed to update_by_id.
    monkeypatch.setattr(
        module,
        "get_request_json",
        lambda: _AwaitableValue({"id": "conn-1", "refresh_freq": 7, "config": {"x": 1}}),
    )
    res = _run(module.set_connector())
    assert update_calls == [("conn-1", {"refresh_freq": 7, "config": {"x": 1}})]
    assert res["data"]["id"] == "conn-1"

    # set_connector without an id -> create path: generated id, current
    # tenant, and POLL input type filled in.
    monkeypatch.setattr(
        module,
        "get_request_json",
        lambda: _AwaitableValue({"name": "new", "source": "gmail", "config": {"y": 2}}),
    )
    res = _run(module.set_connector())
    assert save_calls[-1]["id"] == "generated-id"
    assert save_calls[-1]["tenant_id"] == "tenant-1"
    assert save_calls[-1]["input_type"] == module.InputType.POLL
    assert res["data"]["id"] == "generated-id"

    list_res = module.list_connector()
    assert list_res["data"] == [{"id": "listed", "tenant": "tenant-1"}]

    # get_connector: lookup failure yields the dedicated error message.
    monkeypatch.setattr(module.ConnectorService, "get_by_id", lambda _cid: (False, None))
    missing_res = module.get_connector("missing")
    assert missing_res["message"] == "Can't find this Connector!"

    monkeypatch.setattr(module.ConnectorService, "get_by_id", lambda cid: (True, _FakeConnectorRecord({"id": cid})))
    found_res = module.get_connector("conn-2")
    assert found_res["data"]["id"] == "conn-2"

    # list_logs reads page/page_size from query args and wraps total + rows.
    _set_request(module, args={"page": "2", "page_size": "7"})
    logs_res = module.list_logs("conn-log")
    assert logs_res["data"] == {"total": 9, "logs": [{"id": "log-1"}]}

    # resume flag True -> SCHEDULE, False -> CANCEL.
    monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"resume": True}))
    assert _run(module.resume("conn-r1"))["data"] is True

    monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"resume": False}))
    assert _run(module.resume("conn-r2"))["data"] is True
    assert ("conn-r1", module.TaskStatus.SCHEDULE) in resume_calls
    assert ("conn-r2", module.TaskStatus.CANCEL) in resume_calls

    # rebuild: a non-None return value from the service is an error message.
    monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"kb_id": "kb-1"}))
    monkeypatch.setattr(module.ConnectorService, "rebuild", lambda *_args: "rebuild-failed")
    failed_rebuild = _run(module.rebuild("conn-rb"))
    assert failed_rebuild["code"] == module.RetCode.SERVER_ERROR
    assert failed_rebuild["data"] is False

    monkeypatch.setattr(module.ConnectorService, "rebuild", lambda *_args: None)
    ok_rebuild = _run(module.rebuild("conn-rb"))
    assert ok_rebuild["data"] is True

    # Deleting a connector cancels its task before removing the row.
    rm_res = module.rm_connector("conn-rm")
    assert rm_res["data"] is True
    assert ("conn-rm", module.TaskStatus.CANCEL) in resume_calls
    assert delete_calls == ["conn-rm"]
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_connector_oauth_helper_functions(monkeypatch):
    """Cover the private OAuth helpers: cache-key builders, credential
    loading/validation, and the HTML popup renderer for both outcomes."""
    module = _load_connector_app(monkeypatch)

    # Cache keys embed the source name and flow id.
    assert module._web_state_cache_key("flow-a", "gmail") == "gmail_web_flow_state:flow-a"
    assert module._web_result_cache_key("flow-b", "google-drive") == "google-drive_web_flow_result:flow-b"

    # _load_credentials accepts either a dict or a JSON string...
    creds_dict = {"web": {"client_id": "id"}}
    assert module._load_credentials(creds_dict) == creds_dict
    assert module._load_credentials(json.dumps(creds_dict)) == creds_dict

    # ...and rejects malformed JSON with a descriptive ValueError.
    with pytest.raises(ValueError, match="Invalid Google credentials JSON"):
        module._load_credentials("{not-json")

    # Web client config requires a top-level "web" section.
    assert module._get_web_client_config(creds_dict) == {"web": {"client_id": "id"}}
    with pytest.raises(ValueError, match="must include a 'web'"):
        module._get_web_client_config({"installed": {"client_id": "id"}})

    # Success popup: HTML response with the source-specific message channel.
    popup_ok = _run(module._render_web_oauth_popup("flow-1", True, "done", "gmail"))
    assert popup_ok.status_code == 200
    assert popup_ok.headers["Content-Type"] == "text/html; charset=utf-8"
    assert "Authorization complete" in popup_ok.body
    assert "ragflow-gmail-oauth" in popup_ok.body

    # Failure popup still renders 200 HTML, embedding the error message.
    popup_error = _run(module._render_web_oauth_popup("flow-2", False, "<denied>", "google-drive"))
    assert popup_error.status_code == 200
    assert "Authorization failed" in popup_error.body
    assert "<denied>" in popup_error.body
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_start_google_web_oauth_matrix(monkeypatch):
    """Matrix test for the start_google_web_oauth endpoint.

    Walks every validation branch (bad connector type, missing redirect URI,
    malformed credentials, stray refresh_token, missing 'web' section) and
    then the happy path for both Gmail and Google Drive, checking that the
    flow state is cached in Redis under the generated flow id.
    """
    module = _load_connector_app(monkeypatch)

    # Fake Redis + frozen clock so cached state and TTLs are deterministic.
    redis = _FakeRedis()
    monkeypatch.setattr(module, "REDIS_CONN", redis)
    monkeypatch.setattr(module.time, "time", lambda: 1700000000)

    # Record every Flow the endpoint builds so scopes can be asserted later.
    flow_calls = []

    def _from_client_config(client_config, scopes):
        flow = _FakeFlow(client_config, scopes)
        flow_calls.append(flow)
        return flow

    monkeypatch.setattr(module.Flow, "from_client_config", staticmethod(_from_client_config))

    # Unknown connector type -> argument error.
    _set_request(module, args={"type": "invalid"})
    monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"credentials": "{}"}))
    invalid_type = _run(module.start_google_web_oauth())
    assert invalid_type["code"] == module.RetCode.ARGUMENT_ERROR

    # Empty redirect URI is a server-side misconfiguration.
    monkeypatch.setattr(module, "GMAIL_WEB_OAUTH_REDIRECT_URI", "")
    _set_request(module, args={"type": "gmail"})
    missing_redirect = _run(module.start_google_web_oauth())
    assert missing_redirect["code"] == module.RetCode.SERVER_ERROR

    monkeypatch.setattr(module, "GMAIL_WEB_OAUTH_REDIRECT_URI", "https://example.com/gmail")
    monkeypatch.setattr(module, "GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI", "https://example.com/drive")

    # Credentials that are not valid JSON -> argument error.
    _set_request(module, args={"type": "google-drive"})
    monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"credentials": "{invalid-json"}))
    invalid_credentials = _run(module.start_google_web_oauth())
    assert invalid_credentials["code"] == module.RetCode.ARGUMENT_ERROR

    # Credentials that already carry a refresh_token are rejected.
    monkeypatch.setattr(
        module,
        "get_request_json",
        lambda: _AwaitableValue({"credentials": json.dumps({"web": {"client_id": "id"}, "refresh_token": "rt"})}),
    )
    has_refresh_token = _run(module.start_google_web_oauth())
    assert has_refresh_token["code"] == module.RetCode.ARGUMENT_ERROR

    # Credentials missing the required "web" client section are rejected.
    monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"credentials": json.dumps({"installed": {"x": 1}})}))
    missing_web = _run(module.start_google_web_oauth())
    assert missing_web["code"] == module.RetCode.ARGUMENT_ERROR

    # Deterministic flow ids: first call gets "flow-gmail", second "flow-drive".
    ids = iter(["flow-gmail", "flow-drive"])
    monkeypatch.setattr(module.uuid, "uuid4", lambda: next(ids))

    monkeypatch.setattr(
        module,
        "get_request_json",
        lambda: _AwaitableValue({"credentials": json.dumps({"web": {"client_id": "id", "client_secret": "secret"}})}),
    )

    # Happy path: explicit gmail type.
    _set_request(module, args={"type": "gmail"})
    gmail_ok = _run(module.start_google_web_oauth())
    assert gmail_ok["code"] == 0
    assert gmail_ok["data"]["flow_id"] == "flow-gmail"
    assert gmail_ok["data"]["authorization_url"].endswith("flow-gmail")

    # Happy path: no type arg — endpoint defaults to google-drive here.
    _set_request(module, args={})
    drive_ok = _run(module.start_google_web_oauth())
    assert drive_ok["code"] == 0
    assert drive_ok["data"]["flow_id"] == "flow-drive"
    assert drive_ok["data"]["authorization_url"].endswith("flow-drive")

    # Each source requested its own scope set, and state was cached per flow.
    assert any(call.scopes == module.GOOGLE_SCOPES[module.DocumentSource.GMAIL] for call in flow_calls)
    assert any(call.scopes == module.GOOGLE_SCOPES[module.DocumentSource.GOOGLE_DRIVE] for call in flow_calls)
    assert "gmail_web_flow_state:flow-gmail" in redis.store
    assert "google-drive_web_flow_state:flow-drive" in redis.store
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_google_web_oauth_callbacks_matrix(monkeypatch):
    """Run the same branch matrix against both Google OAuth callback views.

    For Gmail and Google Drive alike, verifies the HTML responses for:
    missing state, expired state, state without client_config, provider
    error, missing code, and the full success path (token exchange, result
    cached, state key deleted).
    """
    module = _load_connector_app(monkeypatch)

    flow_calls = []

    def _from_client_config(client_config, scopes):
        flow = _FakeFlow(client_config, scopes)
        flow_calls.append(flow)
        return flow

    monkeypatch.setattr(module.Flow, "from_client_config", staticmethod(_from_client_config))

    # (callback view, source name, expected redirect URI, expected scopes)
    callback_specs = [
        (
            module.google_gmail_web_oauth_callback,
            "gmail",
            module.GMAIL_WEB_OAUTH_REDIRECT_URI,
            module.GOOGLE_SCOPES[module.DocumentSource.GMAIL],
        ),
        (
            module.google_drive_web_oauth_callback,
            "google-drive",
            module.GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI,
            module.GOOGLE_SCOPES[module.DocumentSource.GOOGLE_DRIVE],
        ),
    ]

    for callback, source, expected_redirect, expected_scopes in callback_specs:
        # Fresh Redis per source so the two iterations cannot interfere.
        redis = _FakeRedis()
        monkeypatch.setattr(module, "REDIS_CONN", redis)

        # No ?state= at all.
        _set_request(module, args={})
        missing_state = _run(callback())
        assert "Missing OAuth state parameter." in missing_state.body

        # State present but nothing cached under it (expired / unknown).
        _set_request(module, args={"state": "sid"})
        expired_state = _run(callback())
        assert "Authorization session expired" in expired_state.body

        # Cached state lacking client_config is treated as invalid and purged.
        redis.store[module._web_state_cache_key("sid", source)] = json.dumps({"user_id": "tenant-1"})
        _set_request(module, args={"state": "sid"})
        invalid_state = _run(callback())
        assert "Authorization session was invalid" in invalid_state.body
        assert module._web_state_cache_key("sid", source) in redis.deleted

        # Provider returned an error: its description is surfaced to the user.
        redis.store[module._web_state_cache_key("sid", source)] = json.dumps({
            "user_id": "tenant-1",
            "client_config": {"web": {"client_id": "cid"}},
        })
        _set_request(module, args={"state": "sid", "error": "denied", "error_description": "permission denied"})
        oauth_error = _run(callback())
        assert "permission denied" in oauth_error.body

        # Valid state but no ?code= from the provider.
        redis.store[module._web_state_cache_key("sid", source)] = json.dumps({
            "user_id": "tenant-1",
            "client_config": {"web": {"client_id": "cid"}},
        })
        _set_request(module, args={"state": "sid"})
        missing_code = _run(callback())
        assert "Missing authorization code" in missing_code.body

        # Success path: code exchanged, result cached, state deleted.
        redis.store[module._web_state_cache_key("sid", source)] = json.dumps({
            "user_id": "tenant-1",
            "client_config": {"web": {"client_id": "cid"}},
        })
        _set_request(module, args={"state": "sid", "code": "code-123"})
        success = _run(callback())
        assert "Authorization completed successfully." in success.body

        result_key = module._web_result_cache_key("sid", source)
        assert result_key in redis.store
        assert module._web_state_cache_key("sid", source) in redis.deleted

        # The Flow used in the exchange carried the per-source configuration.
        assert flow_calls[-1].redirect_uri == expected_redirect
        assert flow_calls[-1].scopes == expected_scopes
        assert flow_calls[-1].token_code == "code-123"
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_poll_google_web_result_matrix(monkeypatch):
    """Branch coverage for poll_google_web_result.

    Checks: invalid connector type, no cached result yet (RUNNING), a result
    owned by a different user (PERMISSION_ERROR), and a successful pickup
    that also deletes the cached result key.
    """
    module = _load_connector_app(monkeypatch)
    redis = _FakeRedis()
    monkeypatch.setattr(module, "REDIS_CONN", redis)

    # Unknown connector type -> argument error.
    _set_request(module, args={"type": "invalid"}, json_body={"flow_id": "flow-1"})
    invalid_type = _run(module.poll_google_web_result())
    assert invalid_type["code"] == module.RetCode.ARGUMENT_ERROR

    # Nothing cached for this flow yet -> still running.
    _set_request(module, args={"type": "gmail"}, json_body={"flow_id": "flow-1"})
    pending = _run(module.poll_google_web_result())
    assert pending["code"] == module.RetCode.RUNNING

    # Result cached for a different user id -> permission error.
    redis.store[module._web_result_cache_key("flow-1", "gmail")] = json.dumps(
        {"user_id": "another-user", "credentials": "token-x"}
    )
    _set_request(module, args={"type": "gmail"}, json_body={"flow_id": "flow-1"})
    permission_error = _run(module.poll_google_web_result())
    assert permission_error["code"] == module.RetCode.PERMISSION_ERROR

    # Matching user -> credentials returned and result key consumed.
    redis.store[module._web_result_cache_key("flow-1", "gmail")] = json.dumps(
        {"user_id": "tenant-1", "credentials": "token-ok"}
    )
    _set_request(module, args={"type": "gmail"}, json_body={"flow_id": "flow-1"})
    success = _run(module.poll_google_web_result())
    assert success["code"] == 0
    assert success["data"] == {"credentials": "token-ok"}
    assert module._web_result_cache_key("flow-1", "gmail") in redis.deleted
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_box_oauth_start_callback_and_poll_matrix(monkeypatch):
    """End-to-end branch matrix for the Box web OAuth flow.

    Covers start (missing params / success), callback (missing state, missing
    code, null session, provider error, successful code exchange), and poll
    (pending, wrong user, success with result-key cleanup).
    """
    module = _load_connector_app(monkeypatch)
    redis = _FakeRedis()
    monkeypatch.setattr(module, "REDIS_CONN", redis)

    # Track every BoxOAuth instance the app constructs.
    created_auth = []

    class _TrackingBoxOAuth(_FakeBoxOAuth):
        def __init__(self, config):
            super().__init__(config)
            created_auth.append(self)

    monkeypatch.setattr(module, "BoxOAuth", _TrackingBoxOAuth)
    # Deterministic flow id and clock.
    monkeypatch.setattr(module.uuid, "uuid4", lambda: "flow-box")
    monkeypatch.setattr(module.time, "time", lambda: 1800000000)

    # Start: empty payload -> argument error.
    monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({}))
    missing_params = _run(module.start_box_web_oauth())
    assert missing_params["code"] == module.RetCode.ARGUMENT_ERROR

    # Start: full client config -> flow id returned, state cached.
    monkeypatch.setattr(
        module,
        "get_request_json",
        lambda: _AwaitableValue({"client_id": "cid", "client_secret": "sec", "redirect_uri": "https://box.local/callback"}),
    )
    start_ok = _run(module.start_box_web_oauth())
    assert start_ok["code"] == 0
    assert start_ok["data"]["flow_id"] == "flow-box"
    assert "authorization_url" in start_ok["data"]
    assert module._web_state_cache_key("flow-box", "box") in redis.store

    # Callback: no query params at all.
    _set_request(module, args={})
    missing_state = _run(module.box_web_oauth_callback())
    assert "Missing OAuth parameters." in missing_state.body

    # Callback: state without an authorization code.
    _set_request(module, args={"state": "flow-box"})
    missing_code = _run(module.box_web_oauth_callback())
    assert "Missing authorization code from Box." in missing_code.body

    # Callback: cached state decodes to JSON null -> invalid session.
    redis.store[module._web_state_cache_key("flow-null", "box")] = "null"
    _set_request(module, args={"state": "flow-null", "code": "abc"})
    invalid_session = _run(module.box_web_oauth_callback())
    assert invalid_session["code"] == module.RetCode.ARGUMENT_ERROR

    # Callback: provider-reported error is surfaced in the popup body.
    redis.store[module._web_state_cache_key("flow-box", "box")] = json.dumps(
        {"user_id": "tenant-1", "client_id": "cid", "client_secret": "sec"}
    )
    _set_request(module, args={"state": "flow-box", "code": "abc", "error": "access_denied", "error_description": "denied"})
    callback_error = _run(module.box_web_oauth_callback())
    assert "denied" in callback_error.body

    # Callback: success — code exchanged, result cached, state key removed.
    redis.store[module._web_state_cache_key("flow-ok", "box")] = json.dumps(
        {"user_id": "tenant-1", "client_id": "cid", "client_secret": "sec"}
    )
    _set_request(module, args={"state": "flow-ok", "code": "code-ok"})
    callback_success = _run(module.box_web_oauth_callback())
    assert "Authorization completed successfully." in callback_success.body
    assert created_auth[-1].exchange_code == "code-ok"
    assert module._web_result_cache_key("flow-ok", "box") in redis.store
    assert module._web_state_cache_key("flow-ok", "box") in redis.deleted

    # Poll: result removed again -> still running.
    monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"flow_id": "flow-ok"}))
    redis.store.pop(module._web_result_cache_key("flow-ok", "box"), None)
    pending = _run(module.poll_box_web_result())
    assert pending["code"] == module.RetCode.RUNNING

    # Poll: result owned by another user -> permission error.
    redis.store[module._web_result_cache_key("flow-ok", "box")] = json.dumps({"user_id": "another-user"})
    permission_error = _run(module.poll_box_web_result())
    assert permission_error["code"] == module.RetCode.PERMISSION_ERROR

    # Poll: matching user -> tokens returned and result key consumed.
    redis.store[module._web_result_cache_key("flow-ok", "box")] = json.dumps(
        {"user_id": "tenant-1", "access_token": "at", "refresh_token": "rt"}
    )
    poll_success = _run(module.poll_box_web_result())
    assert poll_success["code"] == 0
    assert poll_success["data"]["credentials"]["access_token"] == "at"
    assert module._web_result_cache_key("flow-ok", "box") in redis.deleted
|
||||
@ -578,7 +578,189 @@ def test_sequence2txt_validation_and_transcription_paths(monkeypatch):
|
||||
@pytest.mark.p2
def test_tts_request_parse_entry(monkeypatch):
    """Branch coverage for the tts endpoint.

    Covers: missing tenant, tenant without a default TTS model, a streaming
    success (text split on the fullwidth period so both chunks appear in the
    SSE stream), and an LLM failure surfaced as an in-stream error payload.
    """
    module = _load_conversation_module(monkeypatch)
    # Fixed request body; "A。B" exercises sentence splitting into two chunks.
    # (A previous redundant _set_request_json call with {"text": "hello"} was
    # removed: it was overwritten here before any request was made.)
    _set_request_json(monkeypatch, module, {"text": "A。B"})

    # No tenant record -> explicit error message.
    monkeypatch.setattr(module.TenantService, "get_info_by", lambda _uid: [])
    res = _run(module.tts())
    assert res["message"] == "Tenant not found!"

    # Tenant exists but has no TTS model configured.
    monkeypatch.setattr(module.TenantService, "get_info_by", lambda _uid: [{"tenant_id": "tenant-1", "tts_id": ""}])
    res = _run(module.tts())
    assert res["message"] == "No default TTS model is set"

    class _TTSOk:
        # Generator stub: yields one audio chunk per non-empty text fragment.
        def tts(self, txt):
            if not txt:
                return []
            yield f"chunk-{txt}".encode("utf-8")

    # Success path: audio/mpeg stream with no-buffering headers.
    monkeypatch.setattr(module.TenantService, "get_info_by", lambda _uid: [{"tenant_id": "tenant-1", "tts_id": "tts-x"}])
    monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: _TTSOk())
    resp = _run(module.tts())
    assert resp.mimetype == "audio/mpeg"
    assert resp.headers.get("Cache-Control") == "no-cache"
    assert resp.headers.get("Connection") == "keep-alive"
    assert resp.headers.get("X-Accel-Buffering") == "no"
    stream_text = _run(_read_sse_text(resp))
    # Both sentence fragments ("A" and "B") must have been synthesized.
    assert "chunk-A" in stream_text
    assert "chunk-B" in stream_text

    class _TTSErr:
        # Stub that fails on first use to drive the error branch.
        def tts(self, _txt):
            raise RuntimeError("tts boom")

    # Failure path: the error is reported inside the stream, not as an HTTP error.
    monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: _TTSErr())
    resp = _run(module.tts())
    stream_text = _run(_read_sse_text(resp))
    assert '"code": 500' in stream_text
    assert "**ERROR**: tts boom" in stream_text
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_delete_msg_and_thumbup_matrix_unit(monkeypatch):
    """Branch matrix for the delete_msg and thumbup conversation endpoints.

    Uses __wrapped__ to bypass the route decorators and call the handlers
    directly. Covers: missing conversation (both endpoints), deleting a
    message pair (user + assistant sharing an id) with reference cleanup,
    a thumbs-up that clears prior feedback, and a thumbs-down that stores
    the supplied feedback text.
    """
    module = _load_conversation_module(monkeypatch)

    # Capture persistence calls instead of writing to a database.
    updates = []
    monkeypatch.setattr(module.ConversationService, "update_by_id", lambda conv_id, payload: updates.append((conv_id, payload)) or True)

    # delete_msg: unknown conversation id.
    _set_request_json(monkeypatch, module, {"conversation_id": "missing", "message_id": "pair-1"})
    monkeypatch.setattr(module.ConversationService, "get_by_id", lambda _id: (False, None))
    res = _run(module.delete_msg.__wrapped__())
    assert res["message"] == "Conversation not found!"

    # delete_msg: both messages with id "pair-1" and their reference entry go.
    conv = _DummyConversation(
        conv_id="conv-del",
        message=[
            {"id": "other", "role": "user"},
            {"id": "pair-1", "role": "user"},
            {"id": "pair-1", "role": "assistant"},
        ],
        reference=[{"chunks": [{"id": "c1"}]}],
    )
    _set_request_json(monkeypatch, module, {"conversation_id": "conv-del", "message_id": "pair-1"})
    monkeypatch.setattr(module.ConversationService, "get_by_id", lambda _id: (True, conv))
    res = _run(module.delete_msg.__wrapped__())
    assert res["code"] == 0
    assert [m["id"] for m in res["data"]["message"]] == ["other"]
    assert res["data"]["reference"] == []
    assert updates[-1][0] == "conv-del"

    # thumbup: unknown conversation id.
    _set_request_json(monkeypatch, module, {"conversation_id": "missing", "message_id": "assistant-1", "thumbup": True})
    monkeypatch.setattr(module.ConversationService, "get_by_id", lambda _id: (False, None))
    res = _run(module.thumbup.__wrapped__())
    assert res["message"] == "Conversation not found!"

    # thumbup True: sets the flag and removes any stale feedback text.
    conv_up = _DummyConversation(
        conv_id="conv-up",
        message=[{"id": "assistant-1", "role": "assistant", "feedback": "old"}],
    )
    _set_request_json(monkeypatch, module, {"conversation_id": "conv-up", "message_id": "assistant-1", "thumbup": True})
    monkeypatch.setattr(module.ConversationService, "get_by_id", lambda _id: (True, conv_up))
    res = _run(module.thumbup.__wrapped__())
    assert res["code"] == 0
    assert res["data"]["message"][0]["thumbup"] is True
    assert "feedback" not in res["data"]["message"][0]

    # thumbup False: stores the caller-provided feedback string.
    conv_down = _DummyConversation(conv_id="conv-down", message=[{"id": "assistant-2", "role": "assistant"}])
    _set_request_json(
        monkeypatch,
        module,
        {"conversation_id": "conv-down", "message_id": "assistant-2", "thumbup": False, "feedback": "needs sources"},
    )
    monkeypatch.setattr(module.ConversationService, "get_by_id", lambda _id: (True, conv_down))
    res = _run(module.thumbup.__wrapped__())
    assert res["code"] == 0
    assert res["data"]["message"][0]["thumbup"] is False
    assert res["data"]["message"][0]["feedback"] == "needs sources"
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_ask_about_stream_search_config_matrix_unit(monkeypatch):
    """ask_about streams SSE answers and forwards the search config.

    Replaces async_ask with a generator that yields one answer and then
    raises, asserting that: the response is an SSE stream, the answer and
    the in-stream error both appear, the terminating `data: true` marker is
    sent, and async_ask received the question/kb_ids/uid/search_config.
    """
    module = _load_conversation_module(monkeypatch)
    _set_request_json(monkeypatch, module, {"question": "q", "kb_ids": ["kb-1"], "search_id": "search-1"})
    monkeypatch.setattr(module.SearchService, "get_detail", lambda _sid: {"search_config": {"mode": "test"}})

    # Record the exact arguments the endpoint forwards to async_ask.
    captured = {}

    async def _fake_async_ask(question, kb_ids, uid, search_config=None):
        captured["question"] = question
        captured["kb_ids"] = kb_ids
        captured["uid"] = uid
        captured["search_config"] = search_config
        yield {"answer": "first"}
        # Raised mid-stream: the endpoint must convert it to an error event.
        raise RuntimeError("ask boom")

    monkeypatch.setattr(module, "async_ask", _fake_async_ask)
    resp = _run(module.ask_about.__wrapped__())
    assert resp.headers["Content-Type"] == "text/event-stream; charset=utf-8"
    sse_text = _run(_read_sse_text(resp))
    assert '"answer": "first"' in sse_text
    assert "**ERROR**: ask boom" in sse_text
    # Stream-termination sentinel (case-insensitive match on the JSON bool).
    assert '"data": true' in sse_text.lower()
    assert captured == {"question": "q", "kb_ids": ["kb-1"], "uid": "user-1", "search_config": {"mode": "test"}}
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_mindmap_and_related_questions_matrix_unit(monkeypatch):
    """Branch coverage for the mindmap and related_questions endpoints.

    mindmap: request kb_ids are merged with the search config's kb_ids, the
    search's tenant is used, and an {"error": ...} result is surfaced as a
    failure message. related_questions: the LLM is built from the search's
    chat_id, non-standard llm_setting keys (e.g. "parameter") are stripped
    from the chat options, and the numbered-list reply is parsed into items.
    """
    module = _load_conversation_module(monkeypatch)

    # Shared search detail: tenant-x owns the search; config brings kb-2/kb-3.
    def _search_detail(_sid):
        return {
            "tenant_id": "tenant-x",
            "search_config": {
                "kb_ids": ["kb-2", "kb-3"],
                "chat_id": "chat-x",
                "llm_setting": {"temperature": 0.2, "parameter": {"k": "v"}},
            },
        }

    monkeypatch.setattr(module.SearchService, "get_detail", _search_detail)

    _set_request_json(monkeypatch, module, {"question": "mindmap-q", "kb_ids": ["kb-1", "kb-2"], "search_id": "search-1"})
    mindmap_calls = {}

    async def _gen_ok(question, kb_ids, tenant_id, search_config):
        mindmap_calls["question"] = question
        mindmap_calls["kb_ids"] = set(kb_ids)
        mindmap_calls["tenant_id"] = tenant_id
        mindmap_calls["search_config"] = search_config
        return {"nodes": [question]}

    # Success: payload passed through, kb id sets unioned, tenant from search.
    monkeypatch.setattr(module, "gen_mindmap", _gen_ok)
    res = _run(module.mindmap.__wrapped__())
    assert res["code"] == 0
    assert res["data"] == {"nodes": ["mindmap-q"]}
    assert mindmap_calls["kb_ids"] == {"kb-1", "kb-2", "kb-3"}
    assert mindmap_calls["tenant_id"] == "tenant-x"
    assert set(mindmap_calls["search_config"]["kb_ids"]) == {"kb-1", "kb-2", "kb-3"}

    async def _gen_error(*_args, **_kwargs):
        return {"error": "mindmap boom"}

    # An "error" key in the generator result becomes a failure response.
    monkeypatch.setattr(module, "gen_mindmap", _gen_error)
    res = _run(module.mindmap.__wrapped__())
    assert "mindmap boom" in res["message"]

    # related_questions: capture how the LLM bundle is built and invoked.
    llm_calls = {}

    class _FakeChat:
        async def async_chat(self, prompt, messages, options):
            llm_calls["prompt"] = prompt
            llm_calls["messages"] = messages
            llm_calls["options"] = options
            # Numbered lines are parsed; the trailing line must be ignored.
            return "1. Alpha\n2. Beta\nignored"

    def _fake_bundle(tenant_id, llm_type, chat_id):
        llm_calls["bundle"] = (tenant_id, llm_type, chat_id)
        return _FakeChat()

    monkeypatch.setattr(module, "LLMBundle", _fake_bundle)
    monkeypatch.setattr(module, "load_prompt", lambda name: f"prompt-{name}")
    _set_request_json(monkeypatch, module, {"question": "solar", "search_id": "search-1"})
    res = _run(module.related_questions.__wrapped__())
    assert res["code"] == 0
    assert res["data"] == ["Alpha", "Beta"]
    assert llm_calls["bundle"][0] == "user-1"
    assert llm_calls["bundle"][2] == "chat-x"
    # Only recognized LLM settings survive; "parameter" is dropped.
    assert llm_calls["options"] == {"temperature": 0.2}
    assert llm_calls["prompt"] == "prompt-related_question"
    assert "Keywords: solar" in llm_calls["messages"][0]["content"]
|
||||
|
||||
@ -0,0 +1,778 @@
|
||||
#
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import importlib.util
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from types import ModuleType, SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class _DummyManager:
|
||||
def route(self, *_args, **_kwargs):
|
||||
def decorator(func):
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class _AwaitableValue:
|
||||
def __init__(self, value):
|
||||
self._value = value
|
||||
|
||||
def __await__(self):
|
||||
async def _co():
|
||||
return self._value
|
||||
|
||||
return _co().__await__()
|
||||
|
||||
|
||||
class _DummyArgs(dict):
|
||||
def get(self, key, default=None, type=None):
|
||||
value = super().get(key, default)
|
||||
if value is None or type is None:
|
||||
return value
|
||||
try:
|
||||
return type(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
|
||||
class _Field:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
def __eq__(self, other):
|
||||
return (self.name, "==", other)
|
||||
|
||||
|
||||
class _KB:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
kb_id="kb-1",
|
||||
name="old",
|
||||
tenant_id="tenant-1",
|
||||
parser_id="naive",
|
||||
parser_config=None,
|
||||
embd_id="embd-1",
|
||||
chunk_num=0,
|
||||
pagerank=0,
|
||||
graphrag_task_id="",
|
||||
raptor_task_id="",
|
||||
):
|
||||
self.id = kb_id
|
||||
self.name = name
|
||||
self.tenant_id = tenant_id
|
||||
self.parser_id = parser_id
|
||||
self.parser_config = parser_config or {}
|
||||
self.embd_id = embd_id
|
||||
self.chunk_num = chunk_num
|
||||
self.pagerank = pagerank
|
||||
self.graphrag_task_id = graphrag_task_id
|
||||
self.raptor_task_id = raptor_task_id
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"id": self.id,
|
||||
"name": self.name,
|
||||
"tenant_id": self.tenant_id,
|
||||
"parser_id": self.parser_id,
|
||||
"parser_config": deepcopy(self.parser_config),
|
||||
"embd_id": self.embd_id,
|
||||
"pagerank": self.pagerank,
|
||||
}
|
||||
|
||||
|
||||
def _run(coro):
    # Drive an async handler coroutine to completion from synchronous test code.
    return asyncio.run(coro)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def auth():
    # Dummy auth token for these isolated unit tests; presumably shadows a
    # shared conftest fixture that would otherwise log into a live server —
    # confirm against the suite's conftest.
    return "unit-auth"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
def set_tenant_info():
    # No-op override: presumably neutralizes an autouse conftest fixture that
    # provisions tenant state against a running service — confirm against the
    # suite's conftest.
    return None
|
||||
|
||||
|
||||
def _set_request_args(monkeypatch, module, args):
    # Replace the module-level `request` proxy with a stub exposing only
    # `.args`, backed by _DummyArgs so `.get(..., type=...)` keeps working.
    monkeypatch.setattr(module, "request", SimpleNamespace(args=_DummyArgs(args)))
|
||||
|
||||
|
||||
def _patch_json_parser(monkeypatch, module, payload_state, err_state=None):
|
||||
async def _parse_json(*_args, **_kwargs):
|
||||
return deepcopy(payload_state), err_state
|
||||
|
||||
monkeypatch.setattr(module, "validate_and_parse_json_request", _parse_json)
|
||||
|
||||
|
||||
def _load_dataset_module(monkeypatch):
|
||||
repo_root = Path(__file__).resolve().parents[4]
|
||||
|
||||
quart_mod = ModuleType("quart")
|
||||
quart_mod.Request = type("Request", (), {})
|
||||
quart_mod.request = SimpleNamespace(args=_DummyArgs())
|
||||
monkeypatch.setitem(sys.modules, "quart", quart_mod)
|
||||
|
||||
api_pkg = ModuleType("api")
|
||||
api_pkg.__path__ = [str(repo_root / "api")]
|
||||
monkeypatch.setitem(sys.modules, "api", api_pkg)
|
||||
|
||||
utils_pkg = ModuleType("api.utils")
|
||||
utils_pkg.__path__ = [str(repo_root / "api" / "utils")]
|
||||
monkeypatch.setitem(sys.modules, "api.utils", utils_pkg)
|
||||
api_pkg.utils = utils_pkg
|
||||
|
||||
apps_pkg = ModuleType("api.apps")
|
||||
apps_pkg.__path__ = [str(repo_root / "api" / "apps")]
|
||||
monkeypatch.setitem(sys.modules, "api.apps", apps_pkg)
|
||||
api_pkg.apps = apps_pkg
|
||||
|
||||
sdk_pkg = ModuleType("api.apps.sdk")
|
||||
sdk_pkg.__path__ = [str(repo_root / "api" / "apps" / "sdk")]
|
||||
monkeypatch.setitem(sys.modules, "api.apps.sdk", sdk_pkg)
|
||||
apps_pkg.sdk = sdk_pkg
|
||||
|
||||
db_pkg = ModuleType("api.db")
|
||||
db_pkg.__path__ = []
|
||||
monkeypatch.setitem(sys.modules, "api.db", db_pkg)
|
||||
api_pkg.db = db_pkg
|
||||
|
||||
db_models_mod = ModuleType("api.db.db_models")
|
||||
db_models_mod.File = SimpleNamespace(
|
||||
source_type=_Field("source_type"),
|
||||
id=_Field("id"),
|
||||
type=_Field("type"),
|
||||
name=_Field("name"),
|
||||
)
|
||||
monkeypatch.setitem(sys.modules, "api.db.db_models", db_models_mod)
|
||||
|
||||
services_pkg = ModuleType("api.db.services")
|
||||
services_pkg.__path__ = []
|
||||
monkeypatch.setitem(sys.modules, "api.db.services", services_pkg)
|
||||
|
||||
document_service_mod = ModuleType("api.db.services.document_service")
|
||||
|
||||
class _StubDocumentService:
|
||||
@staticmethod
|
||||
def query(**_kwargs):
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def remove_document(*_args, **_kwargs):
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def get_by_kb_id(**_kwargs):
|
||||
return [], 0
|
||||
|
||||
document_service_mod.DocumentService = _StubDocumentService
|
||||
document_service_mod.queue_raptor_o_graphrag_tasks = lambda **_kwargs: "task-queued"
|
||||
monkeypatch.setitem(sys.modules, "api.db.services.document_service", document_service_mod)
|
||||
services_pkg.document_service = document_service_mod
|
||||
|
||||
file2document_service_mod = ModuleType("api.db.services.file2document_service")
|
||||
|
||||
class _StubFile2DocumentService:
|
||||
@staticmethod
|
||||
def get_by_document_id(_doc_id):
|
||||
return [SimpleNamespace(file_id="file-1")]
|
||||
|
||||
@staticmethod
|
||||
def delete_by_document_id(_doc_id):
|
||||
return None
|
||||
|
||||
file2document_service_mod.File2DocumentService = _StubFile2DocumentService
|
||||
monkeypatch.setitem(sys.modules, "api.db.services.file2document_service", file2document_service_mod)
|
||||
services_pkg.file2document_service = file2document_service_mod
|
||||
|
||||
file_service_mod = ModuleType("api.db.services.file_service")
|
||||
|
||||
class _StubFileService:
|
||||
@staticmethod
|
||||
def filter_delete(_filters):
|
||||
return None
|
||||
|
||||
file_service_mod.FileService = _StubFileService
|
||||
monkeypatch.setitem(sys.modules, "api.db.services.file_service", file_service_mod)
|
||||
services_pkg.file_service = file_service_mod
|
||||
|
||||
knowledgebase_service_mod = ModuleType("api.db.services.knowledgebase_service")
|
||||
|
||||
class _StubKnowledgebaseService:
|
||||
@staticmethod
|
||||
def create_with_name(**_kwargs):
|
||||
return True, {"id": "kb-1"}
|
||||
|
||||
@staticmethod
|
||||
def save(**_kwargs):
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def get_by_id(_kb_id):
|
||||
return True, _KB()
|
||||
|
||||
@staticmethod
|
||||
def query(**_kwargs):
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def get_or_none(**_kwargs):
|
||||
return _KB()
|
||||
|
||||
@staticmethod
|
||||
def delete_by_id(_kb_id):
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def update_by_id(_kb_id, _payload):
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def get_kb_by_id(_kb_id, _tenant_id):
|
||||
return [SimpleNamespace(id=_kb_id)]
|
||||
|
||||
@staticmethod
|
||||
def get_kb_by_name(_name, _tenant_id):
|
||||
return [SimpleNamespace(name=_name)]
|
||||
|
||||
@staticmethod
|
||||
def get_list(*_args, **_kwargs):
|
||||
return [], 0
|
||||
|
||||
@staticmethod
|
||||
def accessible(_dataset_id, _tenant_id):
|
||||
return True
|
||||
|
||||
knowledgebase_service_mod.KnowledgebaseService = _StubKnowledgebaseService
|
||||
monkeypatch.setitem(sys.modules, "api.db.services.knowledgebase_service", knowledgebase_service_mod)
|
||||
services_pkg.knowledgebase_service = knowledgebase_service_mod
|
||||
|
||||
task_service_mod = ModuleType("api.db.services.task_service")
|
||||
|
||||
class _StubTaskService:
|
||||
@staticmethod
|
||||
def get_by_id(_task_id):
|
||||
return False, None
|
||||
|
||||
task_service_mod.GRAPH_RAPTOR_FAKE_DOC_ID = "fake-doc"
|
||||
task_service_mod.TaskService = _StubTaskService
|
||||
monkeypatch.setitem(sys.modules, "api.db.services.task_service", task_service_mod)
|
||||
services_pkg.task_service = task_service_mod
|
||||
|
||||
user_service_mod = ModuleType("api.db.services.user_service")
|
||||
|
||||
class _StubTenantService:
|
||||
@staticmethod
|
||||
def get_by_id(_tenant_id):
|
||||
return True, SimpleNamespace(embd_id="embd-default")
|
||||
|
||||
@staticmethod
|
||||
def get_joined_tenants_by_user_id(_tenant_id):
|
||||
return [{"tenant_id": "tenant-1"}]
|
||||
|
||||
user_service_mod.TenantService = _StubTenantService
|
||||
monkeypatch.setitem(sys.modules, "api.db.services.user_service", user_service_mod)
|
||||
services_pkg.user_service = user_service_mod
|
||||
|
||||
constants_mod = ModuleType("common.constants")
|
||||
|
||||
class _RetCode:
|
||||
SUCCESS = 0
|
||||
ARGUMENT_ERROR = 101
|
||||
DATA_ERROR = 102
|
||||
AUTHENTICATION_ERROR = 108
|
||||
|
||||
class _FileSource:
|
||||
KNOWLEDGEBASE = "knowledgebase"
|
||||
|
||||
class _StatusEnum(Enum):
|
||||
VALID = "valid"
|
||||
|
||||
constants_mod.RetCode = _RetCode
|
||||
constants_mod.FileSource = _FileSource
|
||||
constants_mod.StatusEnum = _StatusEnum
|
||||
constants_mod.PAGERANK_FLD = "pagerank"
|
||||
monkeypatch.setitem(sys.modules, "common.constants", constants_mod)
|
||||
|
||||
common_pkg = ModuleType("common")
|
||||
common_pkg.__path__ = [str(repo_root / "common")]
|
||||
common_pkg.settings = SimpleNamespace(
|
||||
docStoreConn=SimpleNamespace(
|
||||
delete_idx=lambda *_args, **_kwargs: None,
|
||||
delete=lambda *_args, **_kwargs: None,
|
||||
update=lambda *_args, **_kwargs: None,
|
||||
index_exist=lambda *_args, **_kwargs: False,
|
||||
),
|
||||
retriever=SimpleNamespace(search=lambda *_args, **_kwargs: _AwaitableValue(SimpleNamespace(ids=[], field={}))),
|
||||
)
|
||||
monkeypatch.setitem(sys.modules, "common", common_pkg)
|
||||
|
||||
api_utils_mod = ModuleType("api.utils.api_utils")
|
||||
|
||||
def _deep_merge(base, updates):
|
||||
merged = deepcopy(base)
|
||||
for key, value in updates.items():
|
||||
if isinstance(value, dict) and isinstance(merged.get(key), dict):
|
||||
merged[key] = _deep_merge(merged[key], value)
|
||||
else:
|
||||
merged[key] = value
|
||||
return merged
|
||||
|
||||
def _get_result(*, data=None, message="", code=_RetCode.SUCCESS, total=None):
|
||||
payload = {"code": code, "data": data, "message": message}
|
||||
if total is not None:
|
||||
payload["total"] = total
|
||||
return payload
|
||||
|
||||
def _get_error_argument_result(message=""):
|
||||
return _get_result(code=_RetCode.ARGUMENT_ERROR, message=message)
|
||||
|
||||
def _get_error_data_result(message=""):
|
||||
return _get_result(code=_RetCode.DATA_ERROR, message=message)
|
||||
|
||||
def _get_error_permission_result(message=""):
|
||||
return _get_result(code=_RetCode.AUTHENTICATION_ERROR, message=message)
|
||||
|
||||
def _token_required(func):
|
||||
@functools.wraps(func)
|
||||
async def _async_wrapper(*args, **kwargs):
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
@functools.wraps(func)
|
||||
def _sync_wrapper(*args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return _async_wrapper if asyncio.iscoroutinefunction(func) else _sync_wrapper
|
||||
|
||||
api_utils_mod.deep_merge = _deep_merge
|
||||
api_utils_mod.get_error_argument_result = _get_error_argument_result
|
||||
api_utils_mod.get_error_data_result = _get_error_data_result
|
||||
api_utils_mod.get_error_permission_result = _get_error_permission_result
|
||||
api_utils_mod.get_parser_config = lambda _chunk_method, _unused: {"auto": True}
|
||||
api_utils_mod.get_result = _get_result
|
||||
api_utils_mod.remap_dictionary_keys = lambda data: data
|
||||
api_utils_mod.token_required = _token_required
|
||||
api_utils_mod.verify_embedding_availability = lambda _embd_id, _tenant_id: (True, None)
|
||||
monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod)
|
||||
|
||||
async def _parse_json(*_args, **_kwargs):
|
||||
return {}, None
|
||||
|
||||
def _parse_args(*_args, **_kwargs):
|
||||
return {"name": "", "page": 1, "page_size": 30, "orderby": "create_time", "desc": True}, None
|
||||
|
||||
validation_spec = importlib.util.spec_from_file_location(
|
||||
"api.utils.validation_utils", repo_root / "api" / "utils" / "validation_utils.py"
|
||||
)
|
||||
validation_mod = importlib.util.module_from_spec(validation_spec)
|
||||
monkeypatch.setitem(sys.modules, "api.utils.validation_utils", validation_mod)
|
||||
validation_spec.loader.exec_module(validation_mod)
|
||||
validation_mod.validate_and_parse_json_request = _parse_json
|
||||
validation_mod.validate_and_parse_request_args = _parse_args
|
||||
|
||||
rag_pkg = ModuleType("rag")
|
||||
rag_pkg.__path__ = []
|
||||
monkeypatch.setitem(sys.modules, "rag", rag_pkg)
|
||||
|
||||
rag_nlp_pkg = ModuleType("rag.nlp")
|
||||
rag_nlp_pkg.__path__ = []
|
||||
monkeypatch.setitem(sys.modules, "rag.nlp", rag_nlp_pkg)
|
||||
|
||||
search_mod = ModuleType("rag.nlp.search")
|
||||
search_mod.index_name = lambda _tenant_id: "idx"
|
||||
monkeypatch.setitem(sys.modules, "rag.nlp.search", search_mod)
|
||||
rag_nlp_pkg.search = search_mod
|
||||
|
||||
module_name = "test_dataset_sdk_routes_unit_module"
|
||||
module_path = repo_root / "api" / "apps" / "sdk" / "dataset.py"
|
||||
spec = importlib.util.spec_from_file_location(module_name, module_path)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
module.manager = _DummyManager()
|
||||
monkeypatch.setitem(sys.modules, module_name, module)
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_create_route_error_matrix_unit(monkeypatch):
    """Walk the error branches of the SDK ``create`` dataset route.

    Each step monkeypatches one service call to fail and asserts that the
    route surfaces the matching error payload.  ``inspect.unwrap`` strips the
    route decorators so the raw handler can be awaited directly via ``_run``.
    The monkeypatch calls are cumulative, so statement order matters.
    """
    module = _load_dataset_module(monkeypatch)
    req_state = {"name": "kb"}
    _patch_json_parser(monkeypatch, module, req_state)

    # 1) create_with_name fails -> its error payload is passed through verbatim.
    monkeypatch.setattr(module.KnowledgebaseService, "create_with_name", lambda **_kwargs: (False, {"code": 777, "message": "early"}))
    res = _run(inspect.unwrap(module.create)("tenant-1"))
    assert res["code"] == 777, res

    # 2) create succeeds but the tenant lookup fails.
    monkeypatch.setattr(module.KnowledgebaseService, "create_with_name", lambda **_kwargs: (True, {"id": "kb-1"}))
    monkeypatch.setattr(module.TenantService, "get_by_id", lambda _tenant_id: (False, None))
    res = _run(inspect.unwrap(module.create)("tenant-1"))
    assert res["message"] == "Tenant not found", res

    # 3) saving the knowledgebase row returns False -> DATA_ERROR.
    monkeypatch.setattr(module.TenantService, "get_by_id", lambda _tenant_id: (True, SimpleNamespace(embd_id="embd-1")))
    monkeypatch.setattr(module.KnowledgebaseService, "save", lambda **_kwargs: False)
    res = _run(inspect.unwrap(module.create)("tenant-1"))
    assert res["code"] == module.RetCode.DATA_ERROR, res

    # 4) save succeeds but the read-back of the new row fails.
    monkeypatch.setattr(module.KnowledgebaseService, "save", lambda **_kwargs: True)
    monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (False, None))
    res = _run(inspect.unwrap(module.create)("tenant-1"))
    assert "Dataset created failed" in res["message"], res

    # 5) save raises -> the route maps the exception to a generic DB failure.
    # ``(_ for _ in ()).throw(...)`` is the idiom for raising from a lambda.
    monkeypatch.setattr(module.KnowledgebaseService, "save", lambda **_kwargs: (_ for _ in ()).throw(RuntimeError("save boom")))
    res = _run(inspect.unwrap(module.create)("tenant-1"))
    assert res["message"] == "Database operation failed", res
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_delete_route_error_summary_matrix_unit(monkeypatch):
    """Exercise the SDK ``delete`` dataset route's per-item error summary.

    Covers: every sub-step failing (0 deletions -> DATA_ERROR), a partial
    success that still reports per-document errors, and a DB-level exception
    on the "delete all" path (``ids`` is None).
    """
    module = _load_dataset_module(monkeypatch)
    req_state = {"ids": ["kb-1"]}
    _patch_json_parser(monkeypatch, module, req_state)

    # All sub-operations fail: document removal, index drop (raises) and the
    # knowledgebase row deletion -> zero successes, DATA_ERROR summary.
    kb = _KB(kb_id="kb-1", name="kb-1", tenant_id="tenant-1")
    monkeypatch.setattr(module.KnowledgebaseService, "get_or_none", lambda **_kwargs: kb)
    monkeypatch.setattr(module.DocumentService, "query", lambda **_kwargs: [SimpleNamespace(id="doc-1")])
    monkeypatch.setattr(module.DocumentService, "remove_document", lambda *_args, **_kwargs: False)
    monkeypatch.setattr(module.settings, "docStoreConn", SimpleNamespace(delete_idx=lambda *_args, **_kwargs: (_ for _ in ()).throw(RuntimeError("drop failed"))))
    monkeypatch.setattr(module.KnowledgebaseService, "delete_by_id", lambda _kb_id: False)
    res = _run(inspect.unwrap(module.delete)("tenant-1"))
    assert res["code"] == module.RetCode.DATA_ERROR, res
    assert "Successfully deleted 0 datasets" in res["message"], res

    # Dataset row deletion now succeeds, but remove_document still returns
    # False -> overall SUCCESS with per-document errors recorded.
    monkeypatch.setattr(module.settings, "docStoreConn", SimpleNamespace(delete_idx=lambda *_args, **_kwargs: None))
    monkeypatch.setattr(module.KnowledgebaseService, "delete_by_id", lambda _kb_id: True)
    res = _run(inspect.unwrap(module.delete)("tenant-1"))
    assert res["code"] == module.RetCode.SUCCESS, res
    assert res["data"]["success_count"] == 1, res
    assert res["data"]["errors"], res

    # ids=None triggers the "delete all" query, which raises OperationalError.
    req_state["ids"] = None
    monkeypatch.setattr(
        module.KnowledgebaseService,
        "query",
        lambda **_kwargs: (_ for _ in ()).throw(module.OperationalError("db down")),
    )
    res = _run(inspect.unwrap(module.delete)("tenant-1"))
    assert res["code"] == module.RetCode.DATA_ERROR, res
    assert res["message"] == "Database operation failed", res
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_update_route_branch_matrix_unit(monkeypatch):
    """Exercise the branch matrix of the SDK ``update`` dataset route.

    Covers: unknown dataset, duplicate-name rejection, embedding-model change
    blocked once chunks exist, pagerank rejected on the infinity doc engine,
    the pagerank set/unset doc-store updates, update failures, and a DB-level
    exception.  Fix vs. the original: ``DOC_ENGINE`` is now set through
    ``monkeypatch.setenv``/``delenv`` instead of mutating ``os.environ``
    directly, so the variable cannot leak into other tests if an assertion
    fails mid-way (monkeypatch restores the environment at teardown).
    """
    module = _load_dataset_module(monkeypatch)
    req_state = {"name": "new"}
    _patch_json_parser(monkeypatch, module, req_state)

    # Unknown dataset id -> permission error.
    monkeypatch.setattr(module.KnowledgebaseService, "get_or_none", lambda **_kwargs: None)
    res = _run(inspect.unwrap(module.update)("tenant-1", "kb-1"))
    assert res["code"] == module.RetCode.AUTHENTICATION_ERROR, res

    kb = _KB(kb_id="kb-1", name="old", chunk_num=0)

    def _get_or_none_duplicate(**kwargs):
        # Lookup by id finds the dataset; lookup by name finds a different
        # dataset with the requested name -> duplicate-name branch.
        if kwargs.get("id"):
            return kb
        if kwargs.get("name"):
            return SimpleNamespace(id="dup")
        return None

    monkeypatch.setattr(module.KnowledgebaseService, "get_or_none", _get_or_none_duplicate)
    req_state.clear()
    req_state.update({"name": "new"})
    res = _run(inspect.unwrap(module.update)("tenant-1", "kb-1"))
    assert "already exists" in res["message"], res

    # Changing embd_id is rejected once the dataset already has chunks.
    kb_chunked = _KB(kb_id="kb-1", name="old", chunk_num=2, embd_id="embd-1")
    monkeypatch.setattr(module.KnowledgebaseService, "get_or_none", lambda **kwargs: kb_chunked if kwargs.get("id") else None)
    req_state.clear()
    req_state.update({"embd_id": "embd-2"})
    res = _run(inspect.unwrap(module.update)("tenant-1", "kb-1"))
    assert "chunk_num" in res["message"], res

    # pagerank is not supported by the "infinity" doc engine.
    kb_rank = _KB(kb_id="kb-1", name="old", pagerank=0)
    monkeypatch.setattr(module.KnowledgebaseService, "get_or_none", lambda **kwargs: kb_rank if kwargs.get("id") else None)
    req_state.clear()
    req_state.update({"pagerank": 3})
    # setenv instead of os.environ[...]: auto-restored even if asserts fail.
    monkeypatch.setenv("DOC_ENGINE", "infinity")
    res = _run(inspect.unwrap(module.update)("tenant-1", "kb-1"))
    assert "doc_engine" in res["message"], res
    # Remove it again for the remaining steps, which expect the default engine.
    monkeypatch.delenv("DOC_ENGINE", raising=False)

    # Setting pagerank > 0 writes the PAGERANK field for the whole kb.
    update_calls = []
    monkeypatch.setattr(module.settings, "docStoreConn", SimpleNamespace(update=lambda *args, **_kwargs: update_calls.append(args)))
    monkeypatch.setattr(module.KnowledgebaseService, "update_by_id", lambda *_args, **_kwargs: True)
    monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, _KB(kb_id="kb-1", pagerank=3)))

    req_state.clear()
    req_state.update({"pagerank": 3})
    res = _run(inspect.unwrap(module.update)("tenant-1", "kb-1"))
    assert res["code"] == module.RetCode.SUCCESS, res
    assert update_calls and update_calls[-1][0] == {"kb_id": "kb-1"}, update_calls

    # Resetting pagerank to 0 removes the field wherever it exists.
    update_calls.clear()
    monkeypatch.setattr(module.KnowledgebaseService, "get_or_none", lambda **kwargs: _KB(kb_id="kb-1", pagerank=3) if kwargs.get("id") else None)
    req_state.clear()
    req_state.update({"pagerank": 0})
    res = _run(inspect.unwrap(module.update)("tenant-1", "kb-1"))
    assert res["code"] == module.RetCode.SUCCESS, res
    assert update_calls and update_calls[-1][0] == {"exists": module.PAGERANK_FLD}, update_calls

    # update_by_id returning False -> explicit update error.
    monkeypatch.setattr(module.KnowledgebaseService, "update_by_id", lambda *_args, **_kwargs: False)
    req_state.clear()
    req_state.update({"description": "changed"})
    res = _run(inspect.unwrap(module.update)("tenant-1", "kb-1"))
    assert "Update dataset error" in res["message"], res

    # Update succeeds but the read-back fails.
    monkeypatch.setattr(module.KnowledgebaseService, "update_by_id", lambda *_args, **_kwargs: True)
    monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (False, None))
    res = _run(inspect.unwrap(module.update)("tenant-1", "kb-1"))
    assert "Dataset created failed" in res["message"], res

    # DB-level exception on the initial lookup -> generic DB failure message.
    monkeypatch.setattr(
        module.KnowledgebaseService,
        "get_or_none",
        lambda **_kwargs: (_ for _ in ()).throw(module.OperationalError("update down")),
    )
    res = _run(inspect.unwrap(module.update)("tenant-1", "kb-1"))
    assert res["message"] == "Database operation failed", res
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_list_knowledge_graph_delete_kg_matrix_unit(monkeypatch):
    """Cover ``list_datasets``, ``knowledge_graph`` and ``delete_knowledge_graph``.

    Checks the DB-failure path of the list route, then drives the knowledge
    graph route through: no access, missing index, empty retrieval, malformed
    graph JSON, and a well-formed graph where low-ranked nodes and edges
    touching absent/self nodes are filtered out.
    """
    module = _load_dataset_module(monkeypatch)

    # list_datasets: service raises OperationalError -> generic DB failure.
    _set_request_args(monkeypatch, module, {"id": "", "name": "", "page": 1, "page_size": 30, "orderby": "create_time", "desc": True})
    monkeypatch.setattr(
        module,
        "validate_and_parse_request_args",
        lambda *_args, **_kwargs: ({"name": "", "page": 1, "page_size": 30, "orderby": "create_time", "desc": True}, None),
    )
    monkeypatch.setattr(
        module.KnowledgebaseService,
        "get_list",
        lambda *_args, **_kwargs: (_ for _ in ()).throw(module.OperationalError("list down")),
    )
    res = module.list_datasets("tenant-1")
    assert res["code"] == module.RetCode.DATA_ERROR, res
    assert res["message"] == "Database operation failed", res

    # knowledge_graph: caller has no access to the dataset.
    monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda *_args, **_kwargs: False)
    res = _run(inspect.unwrap(module.knowledge_graph)("tenant-1", "kb-1"))
    assert res["code"] == module.RetCode.AUTHENTICATION_ERROR, res

    # Index does not exist -> empty graph/mind_map without searching.
    monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda *_args, **_kwargs: True)
    monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, _KB(tenant_id="tenant-1")))
    monkeypatch.setattr(module.search, "index_name", lambda _tenant_id: "idx")
    monkeypatch.setattr(module.settings, "docStoreConn", SimpleNamespace(index_exist=lambda *_args, **_kwargs: False))
    res = _run(inspect.unwrap(module.knowledge_graph)("tenant-1", "kb-1"))
    assert res["data"] == {"graph": {}, "mind_map": {}}, res

    monkeypatch.setattr(module.settings, "docStoreConn", SimpleNamespace(index_exist=lambda *_args, **_kwargs: True))

    # Retrieval returns no hits -> still empty graph/mind_map.
    class _EmptyRetriever:
        async def search(self, *_args, **_kwargs):
            return SimpleNamespace(ids=[], field={})

    monkeypatch.setattr(module.settings, "retriever", _EmptyRetriever())
    res = _run(inspect.unwrap(module.knowledge_graph)("tenant-1", "kb-1"))
    assert res["data"] == {"graph": {}, "mind_map": {}}, res

    # Malformed graph JSON ("{bad") is tolerated and yields an empty graph.
    class _BadRetriever:
        async def search(self, *_args, **_kwargs):
            return SimpleNamespace(ids=["bad"], field={"bad": {"knowledge_graph_kwd": "graph", "content_with_weight": "{bad"}})

    monkeypatch.setattr(module.settings, "retriever", _BadRetriever())
    res = _run(inspect.unwrap(module.knowledge_graph)("tenant-1", "kb-1"))
    assert res["code"] == module.RetCode.SUCCESS, res
    assert res["data"]["graph"] == {}, res

    # Well-formed graph: both nodes survive; only the n1->n2 edge remains
    # (self-loops and edges to nodes missing from "nodes" are dropped).
    payload = {
        "nodes": [{"id": "n2", "pagerank": 2}, {"id": "n1", "pagerank": 5}],
        "edges": [
            {"source": "n1", "target": "n2", "weight": 2},
            {"source": "n1", "target": "n1", "weight": 10},
            {"source": "n1", "target": "n3", "weight": 9},
        ],
    }

    class _GoodRetriever:
        async def search(self, *_args, **_kwargs):
            return SimpleNamespace(ids=["good"], field={"good": {"knowledge_graph_kwd": "graph", "content_with_weight": json.dumps(payload)}})

    monkeypatch.setattr(module.settings, "retriever", _GoodRetriever())
    res = _run(inspect.unwrap(module.knowledge_graph)("tenant-1", "kb-1"))
    assert res["code"] == module.RetCode.SUCCESS, res
    assert len(res["data"]["graph"]["nodes"]) == 2, res
    assert len(res["data"]["graph"]["edges"]) == 1, res

    # delete_knowledge_graph: access denied branch (synchronous handler).
    monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda *_args, **_kwargs: False)
    res = inspect.unwrap(module.delete_knowledge_graph)("tenant-1", "kb-1")
    assert res["code"] == module.RetCode.AUTHENTICATION_ERROR, res
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_run_trace_graphrag_matrix_unit(monkeypatch):
    """Cover the ``run_graphrag`` and ``trace_graphrag`` route branches.

    Captures ``logging.warning`` calls to verify the stale-task and
    save-failure warnings without inspecting log output.
    """
    module = _load_dataset_module(monkeypatch)

    warnings = []
    monkeypatch.setattr(module.logging, "warning", lambda msg, *_args, **_kwargs: warnings.append(msg))

    # Empty dataset id / no access / unknown dataset.
    res = inspect.unwrap(module.run_graphrag)("tenant-1", "")
    assert 'Dataset ID' in res["message"], res

    monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda *_args, **_kwargs: False)
    res = inspect.unwrap(module.run_graphrag)("tenant-1", "kb-1")
    assert res["code"] == module.RetCode.AUTHENTICATION_ERROR, res

    monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda *_args, **_kwargs: True)
    monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (False, None))
    res = inspect.unwrap(module.run_graphrag)("tenant-1", "kb-1")
    assert "Invalid Dataset ID" in res["message"], res

    # Stale task id on the kb (task no longer exists) -> warn and requeue.
    stale_kb = _KB(kb_id="kb-1", graphrag_task_id="task-old")
    monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, stale_kb))
    monkeypatch.setattr(module.TaskService, "get_by_id", lambda _task_id: (False, None))
    monkeypatch.setattr(module.DocumentService, "get_by_kb_id", lambda **_kwargs: ([{"id": "doc-1"}], 1))
    monkeypatch.setattr(module, "queue_raptor_o_graphrag_tasks", lambda **_kwargs: "task-new")
    monkeypatch.setattr(module.KnowledgebaseService, "update_by_id", lambda *_args, **_kwargs: True)
    res = inspect.unwrap(module.run_graphrag)("tenant-1", "kb-1")
    assert res["code"] == module.RetCode.SUCCESS, res
    assert any("GraphRAG" in msg for msg in warnings), warnings

    # Existing task still in progress -> refuse to start another.
    monkeypatch.setattr(module.TaskService, "get_by_id", lambda _task_id: (True, SimpleNamespace(progress=0)))
    res = inspect.unwrap(module.run_graphrag)("tenant-1", "kb-1")
    assert "already running" in res["message"], res

    # Fresh run: capture the queue kwargs; update_by_id fails -> warn but the
    # queued task id is still returned to the caller.
    warnings.clear()
    queue_calls = {}
    no_task_kb = _KB(kb_id="kb-1", graphrag_task_id="")
    monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, no_task_kb))
    monkeypatch.setattr(module.TaskService, "get_by_id", lambda _task_id: (False, None))
    monkeypatch.setattr(module.DocumentService, "get_by_kb_id", lambda **_kwargs: ([{"id": "doc-1"}, {"id": "doc-2"}], 2))

    def _queue(**kwargs):
        queue_calls.update(kwargs)
        return "queued-id"

    monkeypatch.setattr(module, "queue_raptor_o_graphrag_tasks", _queue)
    monkeypatch.setattr(module.KnowledgebaseService, "update_by_id", lambda *_args, **_kwargs: False)
    res = inspect.unwrap(module.run_graphrag)("tenant-1", "kb-1")
    assert res["code"] == module.RetCode.SUCCESS, res
    assert res["data"]["graphrag_task_id"] == "queued-id", res
    assert queue_calls["doc_ids"] == ["doc-1", "doc-2"], queue_calls
    assert any("Cannot save graphrag_task_id" in msg for msg in warnings), warnings

    # trace_graphrag: empty id / no access / unknown dataset.
    res = inspect.unwrap(module.trace_graphrag)("tenant-1", "")
    assert 'Dataset ID' in res["message"], res

    monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda *_args, **_kwargs: False)
    res = inspect.unwrap(module.trace_graphrag)("tenant-1", "kb-1")
    assert res["code"] == module.RetCode.AUTHENTICATION_ERROR, res

    monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda *_args, **_kwargs: True)
    monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (False, None))
    res = inspect.unwrap(module.trace_graphrag)("tenant-1", "kb-1")
    assert "Invalid Dataset ID" in res["message"], res

    # Task id present but task missing -> SUCCESS with empty data.
    monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, _KB(kb_id="kb-1", graphrag_task_id="task-1")))
    monkeypatch.setattr(module.TaskService, "get_by_id", lambda _task_id: (False, None))
    res = inspect.unwrap(module.trace_graphrag)("tenant-1", "kb-1")
    assert res["code"] == module.RetCode.SUCCESS, res
    assert res["data"] == {}, res

    # Task found -> its dict is returned.
    monkeypatch.setattr(module.TaskService, "get_by_id", lambda _task_id: (True, SimpleNamespace(to_dict=lambda: {"id": _task_id, "progress": 1})))
    res = inspect.unwrap(module.trace_graphrag)("tenant-1", "kb-1")
    assert res["code"] == module.RetCode.SUCCESS, res
    assert res["data"]["id"] == "task-1", res
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_run_trace_raptor_matrix_unit(monkeypatch):
    """Cover the ``run_raptor`` and ``trace_raptor`` route branches.

    Mirrors the GraphRAG matrix test: stale/running/fresh task handling for
    ``run_raptor``, then the trace route's error and success paths.  Note the
    asymmetry with GraphRAG: a missing RAPTOR task traces to an error message,
    not an empty payload.
    """
    module = _load_dataset_module(monkeypatch)

    warnings = []
    monkeypatch.setattr(module.logging, "warning", lambda msg, *_args, **_kwargs: warnings.append(msg))

    # Empty dataset id / no access / unknown dataset.
    res = inspect.unwrap(module.run_raptor)("tenant-1", "")
    assert 'Dataset ID' in res["message"], res

    monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda *_args, **_kwargs: False)
    res = inspect.unwrap(module.run_raptor)("tenant-1", "kb-1")
    assert res["code"] == module.RetCode.AUTHENTICATION_ERROR, res

    monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda *_args, **_kwargs: True)
    monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (False, None))
    res = inspect.unwrap(module.run_raptor)("tenant-1", "kb-1")
    assert "Invalid Dataset ID" in res["message"], res

    # Stale task id on the kb (task no longer exists) -> warn and requeue.
    stale_kb = _KB(kb_id="kb-1", raptor_task_id="task-old")
    monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, stale_kb))
    monkeypatch.setattr(module.TaskService, "get_by_id", lambda _task_id: (False, None))
    monkeypatch.setattr(module.DocumentService, "get_by_kb_id", lambda **_kwargs: ([{"id": "doc-1"}], 1))
    monkeypatch.setattr(module, "queue_raptor_o_graphrag_tasks", lambda **_kwargs: "task-new")
    monkeypatch.setattr(module.KnowledgebaseService, "update_by_id", lambda *_args, **_kwargs: True)
    res = inspect.unwrap(module.run_raptor)("tenant-1", "kb-1")
    assert res["code"] == module.RetCode.SUCCESS, res
    assert any("RAPTOR" in msg for msg in warnings), warnings

    # Existing task still in progress -> refuse to start another.
    monkeypatch.setattr(module.TaskService, "get_by_id", lambda _task_id: (True, SimpleNamespace(progress=0)))
    res = inspect.unwrap(module.run_raptor)("tenant-1", "kb-1")
    assert "already running" in res["message"], res

    # Fresh run: update_by_id fails -> warn but still return the queued id.
    warnings.clear()
    no_task_kb = _KB(kb_id="kb-1", raptor_task_id="")
    monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, no_task_kb))
    monkeypatch.setattr(module.DocumentService, "get_by_kb_id", lambda **_kwargs: ([{"id": "doc-1"}], 1))
    monkeypatch.setattr(module, "queue_raptor_o_graphrag_tasks", lambda **_kwargs: "queued-raptor")
    monkeypatch.setattr(module.KnowledgebaseService, "update_by_id", lambda *_args, **_kwargs: False)
    res = inspect.unwrap(module.run_raptor)("tenant-1", "kb-1")
    assert res["code"] == module.RetCode.SUCCESS, res
    assert res["data"]["raptor_task_id"] == "queued-raptor", res
    assert any("Cannot save raptor_task_id" in msg for msg in warnings), warnings

    # trace_raptor: empty id / no access / unknown dataset.
    res = inspect.unwrap(module.trace_raptor)("tenant-1", "")
    assert 'Dataset ID' in res["message"], res

    monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda *_args, **_kwargs: False)
    res = inspect.unwrap(module.trace_raptor)("tenant-1", "kb-1")
    assert res["code"] == module.RetCode.AUTHENTICATION_ERROR, res

    monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda *_args, **_kwargs: True)
    monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (False, None))
    res = inspect.unwrap(module.trace_raptor)("tenant-1", "kb-1")
    assert "Invalid Dataset ID" in res["message"], res

    # Task id present but task missing -> explicit "not found" error.
    monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, _KB(kb_id="kb-1", raptor_task_id="task-1")))
    monkeypatch.setattr(module.TaskService, "get_by_id", lambda _task_id: (False, None))
    res = inspect.unwrap(module.trace_raptor)("tenant-1", "kb-1")
    assert "RAPTOR Task Not Found" in res["message"], res

    # Task found -> its dict is returned even with a failed (-1) progress.
    monkeypatch.setattr(module.TaskService, "get_by_id", lambda _task_id: (True, SimpleNamespace(to_dict=lambda: {"id": _task_id, "progress": -1})))
    res = inspect.unwrap(module.trace_raptor)("tenant-1", "kb-1")
    assert res["code"] == module.RetCode.SUCCESS, res
    assert res["data"]["id"] == "task-1", res
|
||||
@ -0,0 +1,549 @@
|
||||
#
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import importlib.util
|
||||
import inspect
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from types import ModuleType, SimpleNamespace
|
||||
from functools import wraps
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class _DummyManager:
    """Stand-in for the route manager: registering a route is a no-op."""

    def route(self, *_route_args, **_route_kwargs):
        """Return a pass-through decorator so handlers are left untouched."""

        def _identity(handler):
            return handler

        return _identity
|
||||
|
||||
|
||||
class _AwaitableValue:
    """Wrap a value so that ``await`` on this object yields it back."""

    def __init__(self, value):
        self._value = value

    def __await__(self):
        # Delegate to a trivial coroutine that resolves to the stored value.
        async def _resolve(result):
            return result

        return _resolve(self._value).__await__()
|
||||
|
||||
|
||||
class _Args(dict):
    """Minimal stand-in for Quart's ``request.args`` mapping.

    Plain ``dict`` already provides the required ``get(key, default=None)``
    contract, so no override is needed — the previous ``get`` override was a
    pure pass-through to ``dict.get`` and has been removed.
    """
|
||||
|
||||
|
||||
def _run(coro):
    """Drive *coro* to completion on a fresh event loop and return its result."""
    result = asyncio.run(coro)
    return result
|
||||
|
||||
|
||||
def _set_request_json(monkeypatch, module, payload):
    """Patch ``module.get_request_json`` so awaiting its result yields *payload*."""

    def _fake_get_request_json():
        return _AwaitableValue(payload)

    monkeypatch.setattr(module, "get_request_json", _fake_get_request_json)
|
||||
|
||||
|
||||
def _set_request_args(monkeypatch, module, args):
    """Replace ``module.request`` with a stub exposing *args* as query args."""
    stub_request = SimpleNamespace(args=_Args(args))
    monkeypatch.setattr(module, "request", stub_request)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def auth():
    """Session-scoped placeholder auth token required by shared fixtures."""
    return "unit-auth"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
def set_tenant_info():
    """Autouse no-op that overrides any global tenant setup; these unit tests
    stub every service themselves, so no real tenant is needed."""
    return None
|
||||
|
||||
|
||||
def _load_dialog_module(monkeypatch):
    """Import ``api/apps/dialog_app.py`` in isolation with stubbed dependencies.

    Installs fake ``common``/``quart``/``api...`` modules into ``sys.modules``
    via ``monkeypatch.setitem`` (restored automatically at teardown), then
    executes the real dialog_app source from disk against those stubs and
    returns the loaded module.  Stub installation order matters: parent
    packages must exist before their submodules are registered.
    """
    # tests/.../<this file> -> repository root, four directory levels up.
    repo_root = Path(__file__).resolve().parents[4]

    # "common" is a package stub whose __path__ points at the real sources,
    # so submodule imports (e.g. common.constants below) load genuine code.
    common_pkg = ModuleType("common")
    common_pkg.__path__ = [str(repo_root / "common")]
    monkeypatch.setitem(sys.modules, "common", common_pkg)

    # Minimal quart surface: only ``request`` with empty query args.
    quart_mod = ModuleType("quart")
    quart_mod.request = SimpleNamespace(args=_Args())
    monkeypatch.setitem(sys.modules, "quart", quart_mod)

    api_pkg = ModuleType("api")
    api_pkg.__path__ = [str(repo_root / "api")]
    monkeypatch.setitem(sys.modules, "api", api_pkg)

    # api.apps provides the logged-in user and a pass-through login_required.
    apps_mod = ModuleType("api.apps")
    apps_mod.__path__ = [str(repo_root / "api" / "apps")]
    apps_mod.current_user = SimpleNamespace(id="tenant-1")
    apps_mod.login_required = lambda func: func
    monkeypatch.setitem(sys.modules, "api.apps", apps_mod)
    api_pkg.apps = apps_mod

    db_pkg = ModuleType("api.db")
    db_pkg.__path__ = []
    monkeypatch.setitem(sys.modules, "api.db", db_pkg)
    api_pkg.db = db_pkg

    # duplicate_name default: just echo the requested name (no collision).
    services_pkg = ModuleType("api.db.services")
    services_pkg.__path__ = []
    services_pkg.duplicate_name = lambda _checker, **kwargs: kwargs.get("name", "")
    monkeypatch.setitem(sys.modules, "api.db.services", services_pkg)

    dialog_service_mod = ModuleType("api.db.services.dialog_service")

    # Happy-path defaults; individual tests monkeypatch specific methods.
    class _DialogService:
        # ``model`` mimics the peewee model attribute used for ordering.
        model = SimpleNamespace(create_time="create_time")

        @staticmethod
        def query(**_kwargs):
            return []

        @staticmethod
        def save(**_kwargs):
            return True

        @staticmethod
        def update_by_id(*_args, **_kwargs):
            return True

        @staticmethod
        def get_by_id(_id):
            return True, SimpleNamespace(to_dict=lambda: {"id": _id, "kb_ids": []})

        @staticmethod
        def get_by_tenant_ids(*_args, **_kwargs):
            return [], 0

        @staticmethod
        def update_many_by_id(_payload):
            return True

    dialog_service_mod.DialogService = _DialogService
    monkeypatch.setitem(sys.modules, "api.db.services.dialog_service", dialog_service_mod)

    tenant_llm_service_mod = ModuleType("api.db.services.tenant_llm_service")

    class _TenantLLMService:
        @staticmethod
        def split_model_name_and_factory(embd_id):
            # Mirrors the real "model@factory" split convention.
            return embd_id.split("@")

    tenant_llm_service_mod.TenantLLMService = _TenantLLMService
    monkeypatch.setitem(sys.modules, "api.db.services.tenant_llm_service", tenant_llm_service_mod)

    knowledgebase_service_mod = ModuleType("api.db.services.knowledgebase_service")

    class _KnowledgebaseService:
        @staticmethod
        def get_by_ids(_ids):
            return []

        @staticmethod
        def get_by_id(_id):
            return False, None

        @staticmethod
        def query(**_kwargs):
            return []

    knowledgebase_service_mod.KnowledgebaseService = _KnowledgebaseService
    monkeypatch.setitem(sys.modules, "api.db.services.knowledgebase_service", knowledgebase_service_mod)

    user_service_mod = ModuleType("api.db.services.user_service")

    class _TenantService:
        @staticmethod
        def get_by_id(_id):
            return True, SimpleNamespace(llm_id="llm-default")

    class _UserTenantService:
        @staticmethod
        def query(**_kwargs):
            return [SimpleNamespace(tenant_id="tenant-1")]

    user_service_mod.TenantService = _TenantService
    user_service_mod.UserTenantService = _UserTenantService
    monkeypatch.setitem(sys.modules, "api.db.services.user_service", user_service_mod)

    api_utils_mod = ModuleType("api.utils.api_utils")
    # Imported here, after the "common" stub above is in sys.modules, so the
    # lookup goes through the stub package to the real constants module.
    from common.constants import RetCode

    async def _default_request_json():
        return {}

    def _get_data_error_result(code=RetCode.DATA_ERROR, message="Sorry! Data missing!"):
        return {"code": code, "message": message}

    def _get_json_result(code=RetCode.SUCCESS, message="success", data=None):
        return {"code": code, "message": message, "data": data}

    def _server_error_response(error):
        return {"code": RetCode.EXCEPTION_ERROR, "message": repr(error)}

    def _validate_request(*_args, **_kwargs):
        # Pass-through replacement that keeps sync/async handlers callable.
        def _decorator(func):
            if inspect.iscoroutinefunction(func):
                @wraps(func)
                async def _wrapped(*func_args, **func_kwargs):
                    return await func(*func_args, **func_kwargs)

                return _wrapped

            @wraps(func)
            def _wrapped(*func_args, **func_kwargs):
                return func(*func_args, **func_kwargs)

            return _wrapped

        return _decorator

    api_utils_mod.get_request_json = _default_request_json
    api_utils_mod.get_data_error_result = _get_data_error_result
    api_utils_mod.get_json_result = _get_json_result
    api_utils_mod.server_error_response = _server_error_response
    api_utils_mod.validate_request = _validate_request
    monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod)

    # Finally execute the real dialog_app source under a throwaway module
    # name, injecting the dummy route manager before module top-level runs.
    module_name = "test_dialog_routes_unit_module"
    module_path = repo_root / "api" / "apps" / "dialog_app.py"
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    module = importlib.util.module_from_spec(spec)
    module.manager = _DummyManager()
    monkeypatch.setitem(sys.modules, module_name, module)
    spec.loader.exec_module(module)
    return module
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_set_dialog_branch_matrix_unit(monkeypatch):
    """Walk the validation/persistence branches of ``dialog_app.set_dialog``.

    Covers: non-string / blank / over-long names, duplicate-name renaming,
    ``{knowledge}`` prompt checks, unused-parameter checks, missing tenant,
    mixed embedding models, save/update failures, the success path, and the
    generic exception handler.
    """
    module = _load_dialog_module(monkeypatch)
    # Unwrap the route decorators so the handler can be awaited directly.
    handler = inspect.unwrap(module.set_dialog)

    # Name must be a string.
    _set_request_json(monkeypatch, module, {"name": 1, "prompt_config": {"system": "", "parameters": []}})
    res = _run(handler())
    assert res["message"] == "Dialog name must be string."

    # Whitespace-only name is rejected.
    _set_request_json(monkeypatch, module, {"name": " ", "prompt_config": {"system": "", "parameters": []}})
    res = _run(handler())
    assert res["message"] == "Dialog name can't be empty."

    # Name longer than 255 characters is rejected.
    _set_request_json(monkeypatch, module, {"name": "a" * 256, "prompt_config": {"system": "", "parameters": []}})
    res = _run(handler())
    assert res["message"] == "Dialog name length is 256 which is larger than 255"

    captured = {}

    def _dup_name(checker, **kwargs):
        # The route passes a query callable used to detect name collisions.
        assert checker(name=kwargs["name"]) is True
        return kwargs["name"] + " (1)"

    monkeypatch.setattr(module, "duplicate_name", _dup_name)
    monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: [SimpleNamespace(name="new dialog")])
    monkeypatch.setattr(module.TenantService, "get_by_id", lambda _id: (True, SimpleNamespace(llm_id="llm-x")))
    monkeypatch.setattr(module.KnowledgebaseService, "get_by_ids", lambda _ids: [SimpleNamespace(embd_id="embd-a@builtin")])
    monkeypatch.setattr(module.TenantLLMService, "split_model_name_and_factory", lambda embd_id: embd_id.split("@"))
    # Capture the payload handed to save() while forcing a save failure.
    monkeypatch.setattr(module.DialogService, "save", lambda **kwargs: captured.update(kwargs) or False)
    _set_request_json(
        monkeypatch,
        module,
        {
            "name": "New Dialog",
            "kb_ids": ["kb-1"],
            "prompt_config": {"system": "Use {knowledge}", "parameters": []},
        },
    )
    res = _run(handler())
    assert res["message"] == "Fail to new a dialog!"
    assert captured["name"] == "New Dialog (1)"
    assert captured["prompt_config"]["parameters"] == [{"key": "knowledge", "optional": False}]

    # `{knowledge}` in the system prompt with no kb_ids must be rejected.
    _set_request_json(
        monkeypatch,
        module,
        {
            "dialog_id": "dialog-1",
            "name": "Update",
            "kb_ids": [],
            "prompt_config": {
                "system": "Use {knowledge}",
                "parameters": [{"key": "knowledge", "optional": True}],
            },
        },
    )
    res = _run(handler())
    assert "Please remove `{knowledge}` in system prompt" in res["message"]

    # A declared non-optional parameter that never appears in the prompt.
    _set_request_json(
        monkeypatch,
        module,
        {"name": "demo", "prompt_config": {"system": "hello", "parameters": [{"key": "must", "optional": False}]}},
    )
    res = _run(handler())
    assert "Parameter 'must' is not used" in res["message"]

    # Tenant lookup failure.
    monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: [])
    monkeypatch.setattr(module.TenantService, "get_by_id", lambda _id: (False, None))
    _set_request_json(monkeypatch, module, {"name": "demo", "prompt_config": {"system": "hello", "parameters": []}})
    res = _run(handler())
    assert res["message"] == "Tenant not found!"

    # Datasets with differing embedding models must be rejected.
    monkeypatch.setattr(module.TenantService, "get_by_id", lambda _id: (True, SimpleNamespace(llm_id="llm-x")))
    monkeypatch.setattr(
        module,
        "get_request_json",
        lambda: _AwaitableValue(
            {
                "name": "demo",
                "kb_ids": ["kb-1", "kb-2"],
                "prompt_config": {"system": "hello", "parameters": []},
            }
        ),
    )
    monkeypatch.setattr(
        module.KnowledgebaseService,
        "get_by_ids",
        lambda _ids: [SimpleNamespace(embd_id="embd-a@f1"), SimpleNamespace(embd_id="embd-b@f2")],
    )
    monkeypatch.setattr(module.TenantLLMService, "split_model_name_and_factory", lambda embd_id: embd_id.split("@"))
    res = _run(handler())
    assert "Datasets use different embedding models" in res["message"]

    # Optional parameters are tolerated, but save() failing still errors out.
    monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: [])
    monkeypatch.setattr(
        module,
        "get_request_json",
        lambda: _AwaitableValue(
            {
                "name": "optional-param-dialog",
                "prompt_config": {"system": "hello", "parameters": [{"key": "ignored", "optional": True}]},
            }
        ),
    )
    monkeypatch.setattr(module.KnowledgebaseService, "get_by_ids", lambda _ids: [])
    monkeypatch.setattr(module.DialogService, "save", lambda **_kwargs: False)
    res = _run(handler())
    assert res["message"] == "Fail to new a dialog!"

    # Update path: update_by_id() returning False means the dialog is missing.
    monkeypatch.setattr(module.KnowledgebaseService, "get_by_ids", lambda _ids: [])
    monkeypatch.setattr(module.DialogService, "update_by_id", lambda *_args, **_kwargs: False)
    _set_request_json(
        monkeypatch,
        module,
        {
            "dialog_id": "dialog-1",
            "kb_names": ["legacy"],
            "name": "rename",
            "prompt_config": {"system": "hello", "parameters": []},
        },
    )
    res = _run(handler())
    assert res["message"] == "Dialog not found!"

    # Update succeeded but the re-fetch failed.
    monkeypatch.setattr(module.DialogService, "update_by_id", lambda *_args, **_kwargs: True)
    monkeypatch.setattr(module.DialogService, "get_by_id", lambda _id: (False, None))
    _set_request_json(
        monkeypatch,
        module,
        {
            "dialog_id": "dialog-1",
            "name": "rename",
            "prompt_config": {"system": "hello", "parameters": []},
        },
    )
    res = _run(handler())
    assert res["message"] == "Fail to update a dialog!"

    # Happy path: updated dialog is returned with resolved kb names.
    monkeypatch.setattr(module.DialogService, "get_by_id", lambda _id: (True, SimpleNamespace(to_dict=lambda: {"id": _id, "kb_ids": ["kb-1"]})))
    monkeypatch.setattr(
        module.KnowledgebaseService,
        "get_by_id",
        lambda _id: (True, SimpleNamespace(status=module.StatusEnum.VALID.value, name="KB One")),
    )
    _set_request_json(
        monkeypatch,
        module,
        {
            "dialog_id": "dialog-1",
            "kb_names": ["legacy"],
            "name": "new-name",
            "prompt_config": {"system": "hello", "parameters": []},
        },
    )
    res = _run(handler())
    assert res["code"] == 0
    assert res["data"]["name"] == "new-name"
    assert res["data"]["kb_names"] == ["KB One"]

    # Any unexpected exception is routed through server_error_response.
    def _raise_tenant(_id):
        raise RuntimeError("set boom")

    monkeypatch.setattr(module.TenantService, "get_by_id", _raise_tenant)
    _set_request_json(monkeypatch, module, {"name": "demo", "prompt_config": {"system": "hello", "parameters": []}})
    res = _run(handler())
    assert "set boom" in res["message"]
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_get_get_kb_names_and_list_dialogs_exception_matrix_unit(monkeypatch):
    """Cover ``get``, ``get_kb_names`` and the ``list_dialogs`` error branch.

    Invalid knowledgebases are filtered out of the returned kb id/name lists,
    a missing dialog yields a data error, and service exceptions flow through
    ``server_error_response``.
    """
    module = _load_dialog_module(monkeypatch)
    get_handler = inspect.unwrap(module.get)

    # Success: kb-2 lookup fails, so only kb-1 survives the filter.
    monkeypatch.setattr(
        module.DialogService,
        "get_by_id",
        lambda _id: (True, SimpleNamespace(to_dict=lambda: {"id": _id, "kb_ids": ["kb-1", "kb-2"]})),
    )
    monkeypatch.setattr(
        module.KnowledgebaseService,
        "get_by_id",
        lambda kid: (
            (True, SimpleNamespace(status=module.StatusEnum.VALID.value, name="KB-1"))
            if kid == "kb-1"
            else (False, None)
        ),
    )
    _set_request_args(monkeypatch, module, {"dialog_id": "dialog-1"})
    res = get_handler()
    assert res["code"] == 0
    assert res["data"]["kb_ids"] == ["kb-1"]
    assert res["data"]["kb_names"] == ["KB-1"]

    # Missing dialog.
    monkeypatch.setattr(module.DialogService, "get_by_id", lambda _id: (False, None))
    _set_request_args(monkeypatch, module, {"dialog_id": "dialog-missing"})
    res = get_handler()
    assert res["message"] == "Dialog not found!"

    # Exception in the dialog lookup.
    def _raise_get(_id):
        raise RuntimeError("get boom")

    monkeypatch.setattr(module.DialogService, "get_by_id", _raise_get)
    _set_request_args(monkeypatch, module, {"dialog_id": "dialog-1"})
    res = get_handler()
    assert "get boom" in res["message"]

    # get_kb_names keeps only VALID knowledgebases, preserving order.
    monkeypatch.setattr(
        module.KnowledgebaseService,
        "get_by_id",
        lambda kid: (
            (True, SimpleNamespace(status=module.StatusEnum.VALID.value, name=f"KB-{kid}"))
            if kid.startswith("ok")
            else (True, SimpleNamespace(status=module.StatusEnum.INVALID.value, name=f"BAD-{kid}"))
        ),
    )
    ids, names = module.get_kb_names(["ok-1", "bad-1", "ok-2"])
    assert ids == ["ok-1", "ok-2"]
    assert names == ["KB-ok-1", "KB-ok-2"]

    # list_dialogs surfaces query exceptions via server_error_response.
    def _raise_list(**_kwargs):
        raise RuntimeError("list boom")

    monkeypatch.setattr(module.DialogService, "query", _raise_list)
    res = module.list_dialogs()
    assert "list boom" in res["message"]
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_list_dialogs_next_owner_desc_and_pagination_matrix_unit(monkeypatch):
    """Cover owner filtering, desc parsing and pagination of ``list_dialogs_next``.

    With explicit ``owner_ids`` the route fetches everything (page 0 / size 0)
    and slices/filters in Python; without owners it forwards the requested
    page. Exceptions from the service go through ``server_error_response``.
    """
    module = _load_dialog_module(monkeypatch)
    handler = inspect.unwrap(module.list_dialogs_next)

    calls = []

    def _get_by_tenant_ids(tenants, user_id, page_number, items_per_page, orderby, desc, keywords, parser_id):
        # Record every forwarded argument so assertions can inspect them.
        calls.append(
            {
                "tenants": tenants,
                "user_id": user_id,
                "page_number": page_number,
                "items_per_page": items_per_page,
                "orderby": orderby,
                "desc": desc,
                "keywords": keywords,
                "parser_id": parser_id,
            }
        )
        if tenants:
            return (
                [
                    {"id": "dialog-1", "tenant_id": "tenant-a"},
                    {"id": "dialog-2", "tenant_id": "tenant-x"},
                    {"id": "dialog-3", "tenant_id": "tenant-b"},
                ],
                3,
            )
        return ([{"id": "dialog-0", "tenant_id": "tenant-1"}], 1)

    monkeypatch.setattr(module.DialogService, "get_by_tenant_ids", _get_by_tenant_ids)

    # No owner filter: page params are forwarded and "false" parses to False.
    _set_request_args(
        monkeypatch,
        module,
        {
            "keywords": "k",
            "page": "1",
            "page_size": "2",
            "parser_id": "parser-x",
            "orderby": "create_time",
            "desc": "false",
        },
    )
    _set_request_json(monkeypatch, module, {"owner_ids": []})
    res = _run(handler())
    assert res["code"] == 0
    assert res["data"]["total"] == 1
    assert calls[-1]["tenants"] == []
    assert calls[-1]["desc"] is False

    # Owner filter: fetch-all (page 0, size 0), then filter + paginate locally.
    _set_request_args(monkeypatch, module, {"page": "2", "page_size": "1"})
    _set_request_json(monkeypatch, module, {"owner_ids": ["tenant-a", "tenant-b"]})
    res = _run(handler())
    assert res["code"] == 0
    assert res["data"]["total"] == 2
    assert res["data"]["dialogs"] == [{"id": "dialog-3", "tenant_id": "tenant-b"}]
    assert calls[-1]["page_number"] == 0
    assert calls[-1]["items_per_page"] == 0
    assert calls[-1]["desc"] is True

    # Service exception bubbles into the error response.
    def _raise_next(*_args, **_kwargs):
        raise RuntimeError("next boom")

    monkeypatch.setattr(module.DialogService, "get_by_tenant_ids", _raise_next)
    _set_request_args(monkeypatch, module, {"page": "1", "page_size": "1"})
    _set_request_json(monkeypatch, module, {"owner_ids": []})
    res = _run(handler())
    assert "next boom" in res["message"]
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_rm_permission_and_exception_matrix_unit(monkeypatch):
    """``rm`` rejects deletion of dialogs the caller does not own and routes
    service exceptions through ``server_error_response``."""
    module = _load_dialog_module(monkeypatch)
    handler = inspect.unwrap(module.rm)

    # Caller's tenants do not include the dialog's owner -> operating error.
    monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [SimpleNamespace(tenant_id="tenant-a")])
    monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: [])
    _set_request_json(monkeypatch, module, {"dialog_ids": ["dialog-1"]})
    res = _run(handler())
    assert res["code"] == module.RetCode.OPERATING_ERROR
    assert "Only owner of dialog authorized for this operation." in res["message"]

    # Query blowing up is converted into an error response.
    def _raise_query(**_kwargs):
        raise RuntimeError("rm boom")

    monkeypatch.setattr(module.DialogService, "query", _raise_query)
    _set_request_json(monkeypatch, module, {"dialog_ids": ["dialog-1"]})
    res = _run(handler())
    assert "rm boom" in res["message"]
|
||||
@ -250,6 +250,31 @@ def _run(coro):
|
||||
return asyncio.run(coro)
|
||||
|
||||
|
||||
class _DummyArgs:
|
||||
def __init__(self, args=None):
|
||||
self._args = args or {}
|
||||
|
||||
def get(self, key, default=None):
|
||||
return self._args.get(key, default)
|
||||
|
||||
def getlist(self, key):
|
||||
value = self._args.get(key, [])
|
||||
if isinstance(value, list):
|
||||
return value
|
||||
return [value]
|
||||
|
||||
|
||||
class _DummyRequest:
    """Stand-in for the framework request object: exposes only ``.args``."""

    def __init__(self, args=None):
        # Wrap the raw dict so routes can call .get()/.getlist() on it.
        self.args = _DummyArgs(args)
|
||||
|
||||
|
||||
class _DummyResponse:
|
||||
def __init__(self, data=None):
|
||||
self.data = data
|
||||
self.headers = {}
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
class TestDocumentMetadataUnit:
|
||||
def _allow_kb(self, module, monkeypatch, kb_id="kb1", tenant_id="tenant1"):
|
||||
@ -411,3 +436,546 @@ class TestDocumentMetadataUnit:
|
||||
res = _run(module.metadata_update.__wrapped__())
|
||||
assert res["code"] == 0
|
||||
assert res["data"]["matched_docs"] == 1
|
||||
|
||||
def test_metadata_update_invalid_delete_item_unit(self, document_app_module, monkeypatch):
|
||||
module = document_app_module
|
||||
|
||||
async def fake_request_json():
|
||||
return {"kb_id": "kb1", "doc_ids": ["doc1"], "updates": [], "deletes": [{}]}
|
||||
|
||||
monkeypatch.setattr(module, "get_request_json", fake_request_json)
|
||||
res = _run(module.metadata_update.__wrapped__())
|
||||
assert res["code"] == module.RetCode.ARGUMENT_ERROR
|
||||
assert "Each delete requires key." in res["message"]
|
||||
|
||||
def test_update_metadata_setting_authorization_and_refetch_not_found_unit(self, document_app_module, monkeypatch):
|
||||
module = document_app_module
|
||||
|
||||
async def fake_request_json():
|
||||
return {"doc_id": "doc1", "metadata": {"author": "alice"}}
|
||||
|
||||
monkeypatch.setattr(module, "get_request_json", fake_request_json)
|
||||
monkeypatch.setattr(module.DocumentService, "accessible", lambda *_args, **_kwargs: False)
|
||||
res = _run(module.update_metadata_setting.__wrapped__())
|
||||
assert res["code"] == module.RetCode.AUTHENTICATION_ERROR
|
||||
assert "No authorization." in res["message"]
|
||||
|
||||
doc = SimpleNamespace(id="doc1", to_dict=lambda: {"id": "doc1", "parser_config": {}})
|
||||
state = {"count": 0}
|
||||
|
||||
def fake_get_by_id(_doc_id):
|
||||
state["count"] += 1
|
||||
if state["count"] == 1:
|
||||
return True, doc
|
||||
return False, None
|
||||
|
||||
monkeypatch.setattr(module.DocumentService, "accessible", lambda *_args, **_kwargs: True)
|
||||
monkeypatch.setattr(module.DocumentService, "get_by_id", fake_get_by_id)
|
||||
monkeypatch.setattr(module.DocumentService, "update_parser_config", lambda *_args, **_kwargs: True)
|
||||
res = _run(module.update_metadata_setting.__wrapped__())
|
||||
assert res["code"] == module.RetCode.DATA_ERROR
|
||||
assert "Document not found!" in res["message"]
|
||||
|
||||
def test_thumbnails_missing_ids_rewrite_and_exception_unit(self, document_app_module, monkeypatch):
|
||||
module = document_app_module
|
||||
monkeypatch.setattr(module, "request", _DummyRequest(args={}))
|
||||
res = module.thumbnails()
|
||||
assert res["code"] == module.RetCode.ARGUMENT_ERROR
|
||||
assert 'Lack of "Document ID"' in res["message"]
|
||||
|
||||
monkeypatch.setattr(module, "request", _DummyRequest(args={"doc_ids": ["doc1", "doc2"]}))
|
||||
monkeypatch.setattr(
|
||||
module.DocumentService,
|
||||
"get_thumbnails",
|
||||
lambda _doc_ids: [
|
||||
{"id": "doc1", "kb_id": "kb1", "thumbnail": "thumb.jpg"},
|
||||
{"id": "doc2", "kb_id": "kb1", "thumbnail": f"{module.IMG_BASE64_PREFIX}blob"},
|
||||
],
|
||||
)
|
||||
res = module.thumbnails()
|
||||
assert res["code"] == 0
|
||||
assert res["data"]["doc1"] == "/v1/document/image/kb1-thumb.jpg"
|
||||
assert res["data"]["doc2"] == f"{module.IMG_BASE64_PREFIX}blob"
|
||||
|
||||
def raise_error(*_args, **_kwargs):
|
||||
raise RuntimeError("thumb boom")
|
||||
|
||||
monkeypatch.setattr(module.DocumentService, "get_thumbnails", raise_error)
|
||||
monkeypatch.setattr(module, "server_error_response", lambda e: {"code": 500, "message": str(e)})
|
||||
res = module.thumbnails()
|
||||
assert res["code"] == 500
|
||||
assert "thumb boom" in res["message"]
|
||||
|
||||
def test_change_status_partial_failure_matrix_unit(self, document_app_module, monkeypatch):
|
||||
module = document_app_module
|
||||
calls = {"docstore_update": []}
|
||||
doc_ids = ["unauth", "missing_doc", "missing_kb", "update_fail", "docstore_3022", "docstore_generic", "outer_exc"]
|
||||
|
||||
async def fake_request_json():
|
||||
return {"doc_ids": doc_ids, "status": "1"}
|
||||
|
||||
def fake_accessible(doc_id, _uid):
|
||||
return doc_id != "unauth"
|
||||
|
||||
def fake_get_by_id(doc_id):
|
||||
if doc_id == "missing_doc":
|
||||
return False, None
|
||||
if doc_id == "outer_exc":
|
||||
raise RuntimeError("explode")
|
||||
kb_id = "kb_missing" if doc_id == "missing_kb" else "kb1"
|
||||
chunk_num = 1 if doc_id in {"docstore_3022", "docstore_generic"} else 0
|
||||
doc = SimpleNamespace(id=doc_id, kb_id=kb_id, status="0", chunk_num=chunk_num)
|
||||
return True, doc
|
||||
|
||||
def fake_get_kb(kb_id):
|
||||
if kb_id == "kb_missing":
|
||||
return False, None
|
||||
return True, SimpleNamespace(tenant_id="tenant1")
|
||||
|
||||
def fake_update_by_id(doc_id, _payload):
|
||||
return doc_id != "update_fail"
|
||||
|
||||
class _DocStore:
|
||||
def update(self, where, _payload, _index_name, _kb_id):
|
||||
calls["docstore_update"].append(where["doc_id"])
|
||||
if where["doc_id"] == "docstore_3022":
|
||||
raise RuntimeError("3022 table missing")
|
||||
if where["doc_id"] == "docstore_generic":
|
||||
raise RuntimeError("doc store down")
|
||||
return True
|
||||
|
||||
monkeypatch.setattr(module, "get_request_json", fake_request_json)
|
||||
monkeypatch.setattr(module.DocumentService, "accessible", fake_accessible)
|
||||
monkeypatch.setattr(module.DocumentService, "get_by_id", fake_get_by_id)
|
||||
monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda kb_id: fake_get_kb(kb_id))
|
||||
monkeypatch.setattr(module.DocumentService, "update_by_id", fake_update_by_id)
|
||||
monkeypatch.setattr(module.settings, "docStoreConn", _DocStore())
|
||||
monkeypatch.setattr(module.search, "index_name", lambda tenant_id: f"idx_{tenant_id}")
|
||||
|
||||
res = _run(module.change_status.__wrapped__())
|
||||
assert res["code"] == module.RetCode.SERVER_ERROR
|
||||
assert res["message"] == "Partial failure"
|
||||
assert res["data"]["unauth"]["error"] == "No authorization."
|
||||
assert res["data"]["missing_doc"]["error"] == "No authorization."
|
||||
assert res["data"]["missing_kb"]["error"] == "Can't find this dataset!"
|
||||
assert res["data"]["update_fail"]["error"] == "Database error (Document update)!"
|
||||
assert res["data"]["docstore_3022"]["error"] == "Document store table missing."
|
||||
assert "Document store update failed:" in res["data"]["docstore_generic"]["error"]
|
||||
assert "Internal server error: explode" == res["data"]["outer_exc"]["error"]
|
||||
assert calls["docstore_update"] == ["docstore_3022", "docstore_generic"]
|
||||
|
||||
def test_change_status_invalid_status_unit(self, document_app_module, monkeypatch):
|
||||
module = document_app_module
|
||||
|
||||
async def fake_request_json():
|
||||
return {"doc_ids": ["doc1"], "status": "2"}
|
||||
|
||||
monkeypatch.setattr(module, "get_request_json", fake_request_json)
|
||||
res = _run(module.change_status.__wrapped__())
|
||||
assert res["code"] == module.RetCode.ARGUMENT_ERROR
|
||||
assert '"Status" must be either 0 or 1!' in res["message"]
|
||||
|
||||
def test_change_status_all_success_unit(self, document_app_module, monkeypatch):
|
||||
module = document_app_module
|
||||
|
||||
async def fake_request_json():
|
||||
return {"doc_ids": ["doc1"], "status": "1"}
|
||||
|
||||
monkeypatch.setattr(module, "get_request_json", fake_request_json)
|
||||
monkeypatch.setattr(module.DocumentService, "accessible", lambda *_args, **_kwargs: True)
|
||||
monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (True, SimpleNamespace(id="doc1", kb_id="kb1", status="0", chunk_num=0)))
|
||||
monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, SimpleNamespace(tenant_id="tenant1")))
|
||||
monkeypatch.setattr(module.DocumentService, "update_by_id", lambda *_args, **_kwargs: True)
|
||||
res = _run(module.change_status.__wrapped__())
|
||||
assert res["code"] == 0
|
||||
assert res["data"]["doc1"]["status"] == "1"
|
||||
|
||||
def test_rename_branch_matrix_and_exception_unit(self, document_app_module, monkeypatch):
|
||||
module = document_app_module
|
||||
file_updates = []
|
||||
es_updates = []
|
||||
|
||||
async def fake_thread_pool_exec(func, *_args, **_kwargs):
|
||||
return func()
|
||||
|
||||
monkeypatch.setattr(module, "thread_pool_exec", fake_thread_pool_exec)
|
||||
monkeypatch.setattr(module.DocumentService, "get_tenant_id", lambda _doc_id: "tenant1")
|
||||
monkeypatch.setattr(module.rag_tokenizer, "tokenize", lambda _name: ["token"])
|
||||
monkeypatch.setattr(module.rag_tokenizer, "fine_grained_tokenize", lambda _tokens: ["fine"])
|
||||
monkeypatch.setattr(module, "server_error_response", lambda e: {"code": 500, "message": str(e)})
|
||||
|
||||
class _DocStore:
|
||||
def index_exist(self, _index_name, _kb_id):
|
||||
return True
|
||||
|
||||
def update(self, where, payload, _index_name, _kb_id):
|
||||
es_updates.append((where, payload))
|
||||
|
||||
monkeypatch.setattr(module.settings, "docStoreConn", _DocStore())
|
||||
monkeypatch.setattr(module.search, "index_name", lambda tenant_id: f"idx_{tenant_id}")
|
||||
|
||||
def set_req(name):
|
||||
async def fake_request_json():
|
||||
return {"doc_id": "doc1", "name": name}
|
||||
|
||||
monkeypatch.setattr(module, "get_request_json", fake_request_json)
|
||||
|
||||
set_req("renamed.txt")
|
||||
monkeypatch.setattr(module.DocumentService, "accessible", lambda *_args, **_kwargs: False)
|
||||
res = _run(module.rename.__wrapped__())
|
||||
assert res["code"] == module.RetCode.AUTHENTICATION_ERROR
|
||||
|
||||
monkeypatch.setattr(module.DocumentService, "accessible", lambda *_args, **_kwargs: True)
|
||||
monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (False, None))
|
||||
res = _run(module.rename.__wrapped__())
|
||||
assert res["code"] == module.RetCode.DATA_ERROR
|
||||
assert "Document not found!" in res["message"]
|
||||
|
||||
monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (True, SimpleNamespace(id="doc1", name="origin.txt", kb_id="kb1")))
|
||||
set_req("renamed.pdf")
|
||||
res = _run(module.rename.__wrapped__())
|
||||
assert res["code"] == module.RetCode.ARGUMENT_ERROR
|
||||
assert "extension" in res["message"]
|
||||
|
||||
too_long = "a" * (module.FILE_NAME_LEN_LIMIT + 1) + ".txt"
|
||||
set_req(too_long)
|
||||
res = _run(module.rename.__wrapped__())
|
||||
assert res["code"] == module.RetCode.ARGUMENT_ERROR
|
||||
assert "bytes or less" in res["message"]
|
||||
|
||||
set_req("dup.txt")
|
||||
monkeypatch.setattr(module.DocumentService, "query", lambda **_kwargs: [SimpleNamespace(name="dup.txt")])
|
||||
res = _run(module.rename.__wrapped__())
|
||||
assert res["code"] == module.RetCode.DATA_ERROR
|
||||
assert "Duplicated document name" in res["message"]
|
||||
|
||||
set_req("ok.txt")
|
||||
monkeypatch.setattr(module.DocumentService, "query", lambda **_kwargs: [])
|
||||
monkeypatch.setattr(module.DocumentService, "update_by_id", lambda *_args, **_kwargs: True)
|
||||
monkeypatch.setattr(module.File2DocumentService, "get_by_document_id", lambda _doc_id: [SimpleNamespace(file_id="file1")])
|
||||
monkeypatch.setattr(module.FileService, "get_by_id", lambda _file_id: (True, SimpleNamespace(id="file1")))
|
||||
monkeypatch.setattr(module.FileService, "update_by_id", lambda file_id, payload: file_updates.append((file_id, payload)))
|
||||
res = _run(module.rename.__wrapped__())
|
||||
assert res["code"] == 0
|
||||
assert file_updates == [("file1", {"name": "ok.txt"})]
|
||||
assert es_updates[0][0] == {"doc_id": "doc1"}
|
||||
assert es_updates[0][1]["docnm_kwd"] == "ok.txt"
|
||||
assert es_updates[0][1]["title_tks"] == ["token"]
|
||||
assert es_updates[0][1]["title_sm_tks"] == ["fine"]
|
||||
|
||||
def raise_db_error(*_args, **_kwargs):
|
||||
raise RuntimeError("rename boom")
|
||||
|
||||
monkeypatch.setattr(module.DocumentService, "update_by_id", raise_db_error)
|
||||
res = _run(module.rename.__wrapped__())
|
||||
assert res["code"] == 500
|
||||
assert "rename boom" in res["message"]
|
||||
|
||||
def test_get_route_not_found_success_and_exception_unit(self, document_app_module, monkeypatch):
|
||||
module = document_app_module
|
||||
monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (False, None))
|
||||
res = _run(module.get("doc1"))
|
||||
assert res["code"] == module.RetCode.DATA_ERROR
|
||||
assert "Document not found!" in res["message"]
|
||||
|
||||
async def fake_thread_pool_exec(*_args, **_kwargs):
|
||||
return b"blob-data"
|
||||
|
||||
async def fake_make_response(data):
|
||||
return _DummyResponse(data)
|
||||
|
||||
monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (True, SimpleNamespace(name="image.abc", type=module.FileType.VISUAL.value)))
|
||||
monkeypatch.setattr(module.File2DocumentService, "get_storage_address", lambda **_kwargs: ("bucket", "name"))
|
||||
monkeypatch.setattr(module.settings, "STORAGE_IMPL", SimpleNamespace(get=lambda *_args, **_kwargs: b"blob-data"))
|
||||
monkeypatch.setattr(module, "thread_pool_exec", fake_thread_pool_exec)
|
||||
monkeypatch.setattr(module, "make_response", fake_make_response)
|
||||
monkeypatch.setattr(
|
||||
module,
|
||||
"apply_safe_file_response_headers",
|
||||
lambda response, content_type, extension: response.headers.update({"content_type": content_type, "extension": extension}),
|
||||
)
|
||||
res = _run(module.get("doc1"))
|
||||
assert isinstance(res, _DummyResponse)
|
||||
assert res.data == b"blob-data"
|
||||
assert res.headers["content_type"] == "image/abc"
|
||||
assert res.headers["extension"] == "abc"
|
||||
|
||||
monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (_ for _ in ()).throw(RuntimeError("get boom")))
|
||||
monkeypatch.setattr(module, "server_error_response", lambda e: {"code": 500, "message": str(e)})
|
||||
res = _run(module.get("doc1"))
|
||||
assert res["code"] == 500
|
||||
assert "get boom" in res["message"]
|
||||
|
||||
def test_download_attachment_success_and_exception_unit(self, document_app_module, monkeypatch):
|
||||
module = document_app_module
|
||||
monkeypatch.setattr(module, "request", _DummyRequest(args={"ext": "abc"}))
|
||||
|
||||
async def fake_thread_pool_exec(*_args, **_kwargs):
|
||||
return b"attachment"
|
||||
|
||||
async def fake_make_response(data):
|
||||
return _DummyResponse(data)
|
||||
|
||||
monkeypatch.setattr(module, "thread_pool_exec", fake_thread_pool_exec)
|
||||
monkeypatch.setattr(module, "make_response", fake_make_response)
|
||||
monkeypatch.setattr(module.settings, "STORAGE_IMPL", SimpleNamespace(get=lambda *_args, **_kwargs: b"attachment"))
|
||||
monkeypatch.setattr(
|
||||
module,
|
||||
"apply_safe_file_response_headers",
|
||||
lambda response, content_type, extension: response.headers.update({"content_type": content_type, "extension": extension}),
|
||||
)
|
||||
res = _run(module.download_attachment("att1"))
|
||||
assert isinstance(res, _DummyResponse)
|
||||
assert res.data == b"attachment"
|
||||
assert res.headers["content_type"] == "application/abc"
|
||||
assert res.headers["extension"] == "abc"
|
||||
|
||||
async def raise_error(*_args, **_kwargs):
|
||||
raise RuntimeError("download boom")
|
||||
|
||||
monkeypatch.setattr(module, "thread_pool_exec", raise_error)
|
||||
monkeypatch.setattr(module, "server_error_response", lambda e: {"code": 500, "message": str(e)})
|
||||
res = _run(module.download_attachment("att1"))
|
||||
assert res["code"] == 500
|
||||
assert "download boom" in res["message"]
|
||||
|
||||
def test_change_parser_guards_and_reset_update_failure_unit(self, document_app_module, monkeypatch):
|
||||
module = document_app_module
|
||||
|
||||
monkeypatch.setattr(module, "server_error_response", lambda e: {"code": 500, "message": str(e)})
|
||||
|
||||
async def req_auth_fail():
|
||||
return {"doc_id": "doc1", "parser_id": "naive", "pipeline_id": "pipe2"}
|
||||
|
||||
monkeypatch.setattr(module, "get_request_json", req_auth_fail)
|
||||
monkeypatch.setattr(module.DocumentService, "accessible", lambda *_args, **_kwargs: False)
|
||||
res = _run(module.change_parser.__wrapped__())
|
||||
assert res["code"] == module.RetCode.AUTHENTICATION_ERROR
|
||||
|
||||
monkeypatch.setattr(module.DocumentService, "accessible", lambda *_args, **_kwargs: True)
|
||||
monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (False, None))
|
||||
res = _run(module.change_parser.__wrapped__())
|
||||
assert res["code"] == module.RetCode.DATA_ERROR
|
||||
assert "Document not found!" in res["message"]
|
||||
|
||||
async def req_same_pipeline():
|
||||
return {"doc_id": "doc1", "parser_id": "naive", "pipeline_id": "pipe1"}
|
||||
|
||||
doc_same = SimpleNamespace(
|
||||
id="doc1",
|
||||
pipeline_id="pipe1",
|
||||
parser_id="naive",
|
||||
parser_config={"k": "v"},
|
||||
token_num=0,
|
||||
chunk_num=0,
|
||||
process_duration=0,
|
||||
kb_id="kb1",
|
||||
type="doc",
|
||||
name="doc.txt",
|
||||
)
|
||||
monkeypatch.setattr(module, "get_request_json", req_same_pipeline)
|
||||
monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (True, doc_same))
|
||||
res = _run(module.change_parser.__wrapped__())
|
||||
assert res["code"] == 0
|
||||
|
||||
calls = []
|
||||
|
||||
async def req_pipeline_change():
|
||||
return {"doc_id": "doc1", "parser_id": "naive", "pipeline_id": "pipe2"}
|
||||
|
||||
doc = SimpleNamespace(
|
||||
id="doc1",
|
||||
pipeline_id="pipe1",
|
||||
parser_id="naive",
|
||||
parser_config={},
|
||||
token_num=0,
|
||||
chunk_num=0,
|
||||
process_duration=0,
|
||||
kb_id="kb1",
|
||||
type="doc",
|
||||
name="doc.txt",
|
||||
)
|
||||
|
||||
def fake_update_by_id(doc_id, payload):
|
||||
calls.append((doc_id, payload))
|
||||
return True
|
||||
|
||||
monkeypatch.setattr(module, "get_request_json", req_pipeline_change)
|
||||
monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (True, doc))
|
||||
monkeypatch.setattr(module.DocumentService, "update_by_id", fake_update_by_id)
|
||||
res = _run(module.change_parser.__wrapped__())
|
||||
assert res["code"] == 0
|
||||
assert calls[0][1] == {"pipeline_id": "pipe2"}
|
||||
assert calls[1][1]["run"] == module.TaskStatus.UNSTART.value
|
||||
|
||||
doc.token_num = 3
|
||||
doc.chunk_num = 2
|
||||
doc.process_duration = 9
|
||||
monkeypatch.setattr(module.DocumentService, "increment_chunk_num", lambda *_args, **_kwargs: False)
|
||||
res = _run(module.change_parser.__wrapped__())
|
||||
assert res["code"] == 0
|
||||
|
||||
monkeypatch.setattr(module.DocumentService, "increment_chunk_num", lambda *_args, **_kwargs: True)
|
||||
monkeypatch.setattr(module.DocumentService, "get_tenant_id", lambda _doc_id: None)
|
||||
res = _run(module.change_parser.__wrapped__())
|
||||
assert res["code"] == 0
|
||||
|
||||
side_effects = {"img": [], "delete": []}
|
||||
|
||||
class _DocStore:
|
||||
def index_exist(self, _idx, _kb_id):
|
||||
return True
|
||||
|
||||
def delete(self, where, _idx, kb_id):
|
||||
side_effects["delete"].append((where["doc_id"], kb_id))
|
||||
|
||||
monkeypatch.setattr(module.DocumentService, "get_tenant_id", lambda _doc_id: "tenant1")
|
||||
monkeypatch.setattr(module.DocumentService, "delete_chunk_images", lambda _doc, _tenant: side_effects["img"].append((_doc.id, _tenant)))
|
||||
monkeypatch.setattr(module.search, "index_name", lambda tenant_id: f"idx_{tenant_id}")
|
||||
monkeypatch.setattr(module.settings, "docStoreConn", _DocStore())
|
||||
res = _run(module.change_parser.__wrapped__())
|
||||
assert res["code"] == 0
|
||||
assert ("doc1", "tenant1") in side_effects["img"]
|
||||
assert ("doc1", "kb1") in side_effects["delete"]
|
||||
|
||||
async def req_same_parser_with_cfg():
|
||||
return {"doc_id": "doc1", "parser_id": "naive", "parser_config": {"a": 1}}
|
||||
|
||||
doc_same_parser = SimpleNamespace(
|
||||
id="doc1",
|
||||
pipeline_id="pipe1",
|
||||
parser_id="naive",
|
||||
parser_config={"a": 1},
|
||||
token_num=0,
|
||||
chunk_num=0,
|
||||
process_duration=0,
|
||||
kb_id="kb1",
|
||||
type="doc",
|
||||
name="doc.txt",
|
||||
)
|
||||
monkeypatch.setattr(module, "get_request_json", req_same_parser_with_cfg)
|
||||
monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (True, doc_same_parser))
|
||||
res = _run(module.change_parser.__wrapped__())
|
||||
assert res["code"] == 0
|
||||
|
||||
async def req_same_parser_no_cfg():
|
||||
return {"doc_id": "doc1", "parser_id": "naive"}
|
||||
|
||||
monkeypatch.setattr(module, "get_request_json", req_same_parser_no_cfg)
|
||||
res = _run(module.change_parser.__wrapped__())
|
||||
assert res["code"] == 0
|
||||
|
||||
parser_cfg_updates = []
|
||||
|
||||
async def req_parser_update():
|
||||
return {"doc_id": "doc1", "parser_id": "paper", "pipeline_id": "", "parser_config": {"beta": True}}
|
||||
|
||||
doc_parser_update = SimpleNamespace(
|
||||
id="doc1",
|
||||
pipeline_id="pipe1",
|
||||
parser_id="naive",
|
||||
parser_config={"alpha": 1},
|
||||
token_num=0,
|
||||
chunk_num=0,
|
||||
process_duration=0,
|
||||
kb_id="kb1",
|
||||
type="doc",
|
||||
name="doc.txt",
|
||||
)
|
||||
monkeypatch.setattr(module, "get_request_json", req_parser_update)
|
||||
monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (True, doc_parser_update))
|
||||
monkeypatch.setattr(module.DocumentService, "update_parser_config", lambda doc_id, cfg: parser_cfg_updates.append((doc_id, cfg)))
|
||||
monkeypatch.setattr(module.DocumentService, "update_by_id", lambda *_args, **_kwargs: True)
|
||||
res = _run(module.change_parser.__wrapped__())
|
||||
assert res["code"] == 0
|
||||
assert parser_cfg_updates == [("doc1", {"beta": True})]
|
||||
|
||||
def raise_parser_config(*_args, **_kwargs):
|
||||
raise RuntimeError("parser boom")
|
||||
|
||||
monkeypatch.setattr(module.DocumentService, "update_parser_config", raise_parser_config)
|
||||
res = _run(module.change_parser.__wrapped__())
|
||||
assert res["code"] == 500
|
||||
assert "parser boom" in res["message"]
|
||||
|
||||
def test_get_image_success_and_exception_unit(self, document_app_module, monkeypatch):
    """get_image happy path returns the stored bytes with a JPEG Content-Type
    header; a failure in the thread-pool read is routed through
    server_error_response."""
    module = document_app_module

    class _Headers(dict):
        # Minimal headers object exposing the .set() mutator the route calls.
        def set(self, key, value):
            self[key] = value

    class _ImageResponse:
        # Captures the payload handed to make_response plus mutable headers.
        def __init__(self, data):
            self.data = data
            self.headers = _Headers()

    async def fake_thread_pool_exec(*_args, **_kwargs):
        # Pretend the off-thread storage read completed with raw image bytes.
        return b"image-bytes"

    async def fake_make_response(data):
        return _ImageResponse(data)

    monkeypatch.setattr(module, "thread_pool_exec", fake_thread_pool_exec)
    monkeypatch.setattr(module, "make_response", fake_make_response)
    monkeypatch.setattr(module.settings, "STORAGE_IMPL", SimpleNamespace(get=lambda *_args, **_kwargs: b"image-bytes"))
    res = _run(module.get_image("bucket-name"))
    assert isinstance(res, _ImageResponse)
    assert res.data == b"image-bytes"
    # The route hard-codes this (unusual-case) content type.
    assert res.headers["Content-Type"] == "image/JPEG"

    async def raise_error(*_args, **_kwargs):
        raise RuntimeError("image boom")

    # Failure path: the storage error must surface via server_error_response.
    monkeypatch.setattr(module, "thread_pool_exec", raise_error)
    monkeypatch.setattr(module, "server_error_response", lambda e: {"code": 500, "message": str(e)})
    res = _run(module.get_image("bucket-name"))
    assert res["code"] == 500
    assert "image boom" in res["message"]
|
||||
|
||||
def test_set_meta_validation_and_persistence_matrix_unit(self, document_app_module, monkeypatch):
    """Walk set_meta through its validation matrix: auth failure, non-dict
    meta, unsupported value types (dict-in-list, nested dict), JSON syntax
    error, missing document, and a metadata-update database failure."""
    module = document_app_module

    def set_req(payload):
        # Rebind get_request_json so the route sees *payload* as the body.
        async def fake_request_json():
            return payload

        monkeypatch.setattr(module, "get_request_json", fake_request_json)

    # Caller without access is rejected before any parsing.
    set_req({"doc_id": "doc1", "meta": "{}"})
    monkeypatch.setattr(module.DocumentService, "accessible", lambda *_args, **_kwargs: False)
    res = _run(module.set_meta.__wrapped__())
    assert res["code"] == module.RetCode.AUTHENTICATION_ERROR

    # meta must deserialize to a dict, not a list.
    monkeypatch.setattr(module.DocumentService, "accessible", lambda *_args, **_kwargs: True)
    set_req({"doc_id": "doc1", "meta": "[]"})
    res = _run(module.set_meta.__wrapped__())
    assert res["code"] == module.RetCode.ARGUMENT_ERROR
    assert "Only dictionary type supported." in res["message"]

    # Lists may not contain dicts.
    set_req({"doc_id": "doc1", "meta": '{"tags":[{"x":1}]}'})
    res = _run(module.set_meta.__wrapped__())
    assert res["code"] == module.RetCode.ARGUMENT_ERROR
    assert "The type is not supported in list" in res["message"]

    # Nested dict values are rejected as well.
    set_req({"doc_id": "doc1", "meta": '{"obj":{"x":1}}'})
    res = _run(module.set_meta.__wrapped__())
    assert res["code"] == module.RetCode.ARGUMENT_ERROR
    assert "The type is not supported" in res["message"]

    # Malformed JSON is reported as a syntax error.
    set_req({"doc_id": "doc1", "meta": "{"})
    res = _run(module.set_meta.__wrapped__())
    assert res["code"] == module.RetCode.ARGUMENT_ERROR
    assert "Json syntax error:" in res["message"]

    # Valid meta but unknown document.
    set_req({"doc_id": "doc1", "meta": '{"author":"alice"}'})
    monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (False, None))
    res = _run(module.set_meta.__wrapped__())
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "Document not found!" in res["message"]

    # Persistence-layer failure is surfaced as a data error.
    monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (True, SimpleNamespace(id="doc1")))
    monkeypatch.setattr(module.DocMetadataService, "update_document_metadata", lambda *_args, **_kwargs: False)
    res = _run(module.set_meta.__wrapped__())
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "Database error (meta updates)!" in res["message"]
|
||||
|
||||
@ -13,7 +13,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from common import bulk_upload_documents, list_documents, parse_documents
|
||||
@ -22,6 +24,10 @@ from libs.auth import RAGFlowWebApiAuth
|
||||
from utils import wait_for
|
||||
|
||||
|
||||
def _run(coro):
|
||||
return asyncio.run(coro)
|
||||
|
||||
|
||||
@wait_for(30, 1, "Document parsing timeout")
|
||||
def condition(_auth, _kb_id, _document_ids=None):
|
||||
res = list_documents(_auth, {"kb_id": _kb_id})
|
||||
@ -194,6 +200,94 @@ def test_concurrent_parse(WebApiAuth, add_dataset_func, tmp_path):
|
||||
validate_document_parse_done(WebApiAuth, kb_id, document_ids)
|
||||
|
||||
|
||||
@pytest.mark.p2
class TestDocumentsParseUnit:
    """Unit-level branch coverage for the document `run` route."""

    def test_run_branch_matrix_unit(self, document_app_module, monkeypatch):
        """Cover auth failure, missing tenant, missing document, invalid
        cancel, rerun-with-delete bookkeeping, missing KB under apply_kb,
        and an exception raised by DocumentService.run."""
        module = document_app_module
        # Side-effect sinks recorded by the fakes installed below.
        calls = {"clear": [], "filter_delete": [], "docstore_delete": [], "cancel": [], "run": []}

        async def fake_thread_pool_exec(func, *args, **kwargs):
            # Run the offloaded callable inline so assertions stay synchronous.
            return func(*args, **kwargs)

        monkeypatch.setattr(module, "thread_pool_exec", fake_thread_pool_exec)
        monkeypatch.setattr(module, "server_error_response", lambda e: {"code": 500, "message": str(e)})
        monkeypatch.setattr(module.search, "index_name", lambda tenant_id: f"idx_{tenant_id}")
        monkeypatch.setattr(module, "cancel_all_task_of", lambda doc_id: calls["cancel"].append(doc_id))

        class _DocStore:
            # Records chunk deletions issued against the doc store.
            def index_exist(self, _index_name, _kb_id):
                return True

            def delete(self, where, _index_name, _kb_id):
                calls["docstore_delete"].append(where["doc_id"])

        monkeypatch.setattr(module.settings, "docStoreConn", _DocStore())

        async def set_request(payload):
            return payload

        def apply_request(payload):
            # Rebind get_request_json so the route receives *payload*.
            async def fake_request_json():
                return await set_request(payload)

            monkeypatch.setattr(module, "get_request_json", fake_request_json)

        # Inaccessible document -> authentication error.
        apply_request({"doc_ids": ["doc1"], "run": module.TaskStatus.RUNNING.value})
        monkeypatch.setattr(module.DocumentService, "accessible", lambda *_args, **_kwargs: False)
        res = _run(module.run.__wrapped__())
        assert res["code"] == module.RetCode.AUTHENTICATION_ERROR

        # No tenant resolvable for the doc -> data error.
        monkeypatch.setattr(module.DocumentService, "accessible", lambda *_args, **_kwargs: True)
        monkeypatch.setattr(module.DocumentService, "get_tenant_id", lambda _doc_id: None)
        res = _run(module.run.__wrapped__())
        assert res["code"] == module.RetCode.DATA_ERROR
        assert "Tenant not found!" in res["message"]

        # Tenant exists but the document lookup fails.
        monkeypatch.setattr(module.DocumentService, "get_tenant_id", lambda _doc_id: "tenant1")
        monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (False, None))
        res = _run(module.run.__wrapped__())
        assert res["code"] == module.RetCode.DATA_ERROR
        assert "Document not found!" in res["message"]

        # Cancelling a document that is not RUNNING is rejected.
        apply_request({"doc_ids": ["doc1"], "run": module.TaskStatus.CANCEL.value})
        doc_cancel = SimpleNamespace(id="doc1", run=module.TaskStatus.DONE.value, kb_id="kb1", parser_config={}, to_dict=lambda: {"id": "doc1"})
        monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (True, doc_cancel))
        monkeypatch.setattr(module.TaskService, "query", lambda **_kwargs: [SimpleNamespace(progress=1)])
        res = _run(module.run.__wrapped__())
        assert res["code"] == module.RetCode.DATA_ERROR
        assert "Cannot cancel a task that is not in RUNNING status" in res["message"]

        # Rerun with delete=True: chunk counters cleared, stale tasks removed,
        # doc-store chunks deleted, then DocumentService.run dispatched.
        apply_request({"doc_ids": ["doc1"], "run": module.TaskStatus.RUNNING.value, "delete": True})
        doc_rerun = SimpleNamespace(id="doc1", run=module.TaskStatus.DONE.value, kb_id="kb1", parser_config={}, to_dict=lambda: {"id": "doc1"})
        monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (True, doc_rerun))
        monkeypatch.setattr(module.DocumentService, "clear_chunk_num_when_rerun", lambda doc_id: calls["clear"].append(doc_id))
        monkeypatch.setattr(module.TaskService, "filter_delete", lambda _filters: calls["filter_delete"].append(True))
        monkeypatch.setattr(module.DocumentService, "update_by_id", lambda *_args, **_kwargs: True)
        monkeypatch.setattr(module.DocumentService, "run", lambda tenant_id, doc_dict, _kb_map: calls["run"].append((tenant_id, doc_dict)))
        res = _run(module.run.__wrapped__())
        assert res["code"] == 0
        assert calls["clear"] == ["doc1"]
        assert calls["filter_delete"] == [True]
        assert calls["docstore_delete"] == ["doc1"]
        assert calls["run"] == [("tenant1", {"id": "doc1"})]

        # apply_kb=True with an unknown knowledge base -> server error.
        apply_request({"doc_ids": ["doc1"], "run": module.TaskStatus.RUNNING.value, "apply_kb": True})
        monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (False, None))
        res = _run(module.run.__wrapped__())
        assert res["code"] == 500
        assert "Can't find this dataset!" in res["message"]

        # Exception from DocumentService.run propagates as a 500 response.
        apply_request({"doc_ids": ["doc1"], "run": module.TaskStatus.RUNNING.value})

        def raise_run_error(*_args, **_kwargs):
            raise RuntimeError("run boom")

        monkeypatch.setattr(module.DocumentService, "run", raise_run_error)
        res = _run(module.run.__wrapped__())
        assert res["code"] == 500
        assert "run boom" in res["message"]
|
||||
|
||||
|
||||
# @pytest.mark.skip
|
||||
class TestDocumentsParseStop:
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
import pytest
|
||||
@ -21,6 +22,10 @@ from configs import INVALID_API_TOKEN
|
||||
from libs.auth import RAGFlowWebApiAuth
|
||||
|
||||
|
||||
def _run(coro):
|
||||
return asyncio.run(coro)
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
class TestAuthorization:
|
||||
@pytest.mark.parametrize(
|
||||
@ -75,6 +80,32 @@ class TestDocumentsDeletion:
|
||||
assert res["message"] == "No authorization.", res
|
||||
|
||||
|
||||
@pytest.mark.p2
class TestDocumentsDeletionUnit:
    """Unit coverage for the document `rm` route."""

    def test_rm_string_doc_id_normalization_success_unit(self, document_app_module, monkeypatch):
        """A scalar doc_id in the request body is normalized to a one-element
        list before FileService.delete_docs is dispatched off-thread."""
        module = document_app_module
        captured = {}

        async def fake_request_json():
            # Scalar (non-list) doc_id exercises the normalization branch.
            return {"doc_id": "doc1"}

        async def fake_thread_pool_exec(func, doc_ids, user_id):
            # Capture the dispatched callable and its arguments instead of
            # actually deleting anything.
            captured["func"] = func
            captured["doc_ids"] = doc_ids
            captured["user_id"] = user_id
            return None

        monkeypatch.setattr(module, "get_request_json", fake_request_json)
        monkeypatch.setattr(module.DocumentService, "accessible4deletion", lambda *_args, **_kwargs: True)
        monkeypatch.setattr(module, "thread_pool_exec", fake_thread_pool_exec)
        res = _run(module.rm.__wrapped__())
        assert res["code"] == 0
        assert res["data"] is True
        assert captured["func"] == module.FileService.delete_docs
        assert captured["doc_ids"] == ["doc1"]
        assert captured["user_id"] == module.current_user.id
|
||||
|
||||
|
||||
@pytest.mark.p3
|
||||
def test_concurrent_deletion(WebApiAuth, add_dataset, tmp_path):
|
||||
count = 100
|
||||
|
||||
@ -14,8 +14,9 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import sys
|
||||
import string
|
||||
from types import SimpleNamespace
|
||||
from types import ModuleType, SimpleNamespace
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
import pytest
|
||||
@ -333,6 +334,136 @@ class TestDocumentsUploadUnit:
|
||||
assert res["code"] == 102
|
||||
assert "file format" in res["message"]
|
||||
|
||||
def test_upload_and_parse_matrix_unit(self, document_app_module, monkeypatch):
    """upload_and_parse rejects an empty filename, and on a valid upload
    returns the doc ids produced by doc_upload_and_parse."""
    module = document_app_module
    # Empty filename -> "No file selected!" argument error.
    monkeypatch.setattr(module, "request", _DummyRequest(form={"conversation_id": "conv-1"}, files=_DummyFiles({"file": [_DummyFile("")]})))
    res = _run(module.upload_and_parse.__wrapped__())
    assert res["code"] == module.RetCode.ARGUMENT_ERROR
    assert res["message"] == "No file selected!"

    # Valid file: route returns whatever doc_upload_and_parse yields.
    files = _DummyFiles({"file": [_DummyFile("note.txt")]})
    monkeypatch.setattr(module, "request", _DummyRequest(form={"conversation_id": "conv-1"}, files=files))
    monkeypatch.setattr(module, "doc_upload_and_parse", lambda _conv_id, _files, _uid: ["doc-1"])
    res = _run(module.upload_and_parse.__wrapped__())
    assert res["code"] == 0
    assert res["data"] == ["doc-1"]
|
||||
|
||||
def test_parse_url_and_multipart_matrix_unit(self, document_app_module, monkeypatch, tmp_path):
    """parse() branch matrix: invalid URL, URL yielding HTML (parsed via a
    faked selenium-wire Chrome driver), URL yielding a file download, no
    file part, and a plain multipart upload."""
    module = document_app_module

    # Invalid URL is rejected before any browser work.
    async def req_invalid_url():
        return {"url": "not-a-url"}

    monkeypatch.setattr(module, "get_request_json", req_invalid_url)
    monkeypatch.setattr(module, "is_valid_url", lambda _url: False)
    res = _run(module.parse())
    assert res["code"] == module.RetCode.ARGUMENT_ERROR
    assert res["message"] == "The URL format is invalid"

    # Fabricate a seleniumwire.webdriver module so parse() never launches
    # a real browser.
    webdriver_mod = ModuleType("seleniumwire.webdriver")

    class _FakeChromeOptions:
        # Records the arguments/options the route configures on Chrome.
        def __init__(self):
            self.args = []
            self.experimental = {}

        def add_argument(self, arg):
            self.args.append(arg)

        def add_experimental_option(self, key, value):
            self.experimental[key] = value

    class _Req:
        # A captured network request carrying response headers.
        def __init__(self, headers):
            self.response = SimpleNamespace(headers=headers)

    class _FakeDriver:
        # Scripted driver: fixed request list and page source; tracks quit().
        def __init__(self, requests, page_source):
            self.requests = requests
            self.page_source = page_source
            self.quit_called = False
            self.visited = []
            self.options = None

        def get(self, url):
            self.visited.append(url)

        def quit(self):
            self.quit_called = True

    queue = []
    created = []

    def _fake_chrome(options=None):
        # Hand out pre-scripted drivers in FIFO order.
        driver = queue.pop(0)
        driver.options = options
        created.append(driver)
        return driver

    webdriver_mod.Chrome = _fake_chrome
    webdriver_mod.ChromeOptions = _FakeChromeOptions

    seleniumwire_mod = ModuleType("seleniumwire")
    seleniumwire_mod.webdriver = webdriver_mod
    monkeypatch.setitem(sys.modules, "seleniumwire", seleniumwire_mod)
    monkeypatch.setitem(sys.modules, "seleniumwire.webdriver", webdriver_mod)
    monkeypatch.setattr(module, "get_project_base_directory", lambda: str(tmp_path))
    monkeypatch.setattr(module, "is_valid_url", lambda _url: True)

    class _Parser:
        # HTML parser stub; sanity-checks it received the fake page source.
        def parser_txt(self, page_source):
            assert "page" in page_source
            return ["section1", "section2"]

    monkeypatch.setattr(module, "RAGFlowHtmlParser", lambda: _Parser())
    # No content-disposition header -> treated as an HTML page.
    queue.append(_FakeDriver([_Req({"x": "1"}), _Req({"y": "2"})], "<html>page</html>"))

    async def req_url_html():
        return {"url": "http://example.com/html"}

    monkeypatch.setattr(module, "get_request_json", req_url_html)
    res = _run(module.parse())
    assert res["code"] == 0
    assert res["data"] == "section1\nsection2"
    # The driver must always be shut down.
    assert created[-1].quit_called is True

    # A content-disposition attachment triggers the file-download branch,
    # reading the file from <base>/logs/downloads.
    (tmp_path / "logs" / "downloads").mkdir(parents=True, exist_ok=True)
    (tmp_path / "logs" / "downloads" / "doc.txt").write_bytes(b"downloaded-bytes")
    queue.append(_FakeDriver([_Req({"content-disposition": 'attachment; filename="doc.txt"'})], "<html>file</html>"))
    captured = {}

    def parse_docs_read(files, _uid):
        captured["filename"] = files[0].filename
        captured["content"] = files[0].read()
        return "parsed-download"

    monkeypatch.setattr(module.FileService, "parse_docs", parse_docs_read)

    async def req_url_file():
        return {"url": "http://example.com/file"}

    monkeypatch.setattr(module, "get_request_json", req_url_file)
    res = _run(module.parse())
    assert res["code"] == 0
    assert res["data"] == "parsed-download"
    assert captured["filename"] == "doc.txt"
    assert captured["content"] == b"downloaded-bytes"

    # No URL and no multipart file -> argument error.
    async def req_no_url():
        return {}

    monkeypatch.setattr(module, "get_request_json", req_no_url)
    monkeypatch.setattr(module, "request", _DummyRequest(files=_DummyFiles()))
    res = _run(module.parse())
    assert res["code"] == module.RetCode.ARGUMENT_ERROR
    assert res["message"] == "No file part!"

    # Plain multipart upload is handed to FileService.parse_docs.
    monkeypatch.setattr(module, "request", _DummyRequest(files=_DummyFiles({"file": [_DummyFile("f1.txt")]})))
    monkeypatch.setattr(module.FileService, "parse_docs", lambda _files, _uid: "parsed-upload")
    res = _run(module.parse())
    assert res["code"] == 0
    assert res["data"] == "parsed-upload"
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
class TestWebCrawlUnit:
|
||||
|
||||
@ -0,0 +1,575 @@
|
||||
#
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import importlib.util
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from types import ModuleType, SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class _DummyManager:
|
||||
def route(self, *_args, **_kwargs):
|
||||
def decorator(func):
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class _Args(dict):
|
||||
def get(self, key, default=None):
|
||||
return super().get(key, default)
|
||||
|
||||
|
||||
class _DummyRetCode:
|
||||
SUCCESS = 0
|
||||
EXCEPTION_ERROR = 100
|
||||
ARGUMENT_ERROR = 101
|
||||
DATA_ERROR = 102
|
||||
OPERATING_ERROR = 103
|
||||
AUTHENTICATION_ERROR = 109
|
||||
|
||||
|
||||
def _run(coro):
|
||||
return asyncio.run(coro)
|
||||
|
||||
|
||||
def _set_request_json(monkeypatch, module, payload):
|
||||
async def _request_json():
|
||||
return payload
|
||||
|
||||
monkeypatch.setattr(module, "get_request_json", _request_json)
|
||||
|
||||
|
||||
def _set_request_args(monkeypatch, module, args=None):
    """Patch ``module.request`` with a stub whose ``.args`` carries *args*."""
    stub_request = SimpleNamespace(args=_Args(args or {}))
    monkeypatch.setattr(module, "request", stub_request)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def auth():
    """Session-scoped placeholder auth token; these unit tests never make
    real HTTP calls, so any constant value works."""
    token = "unit-auth"
    return token
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
def set_tenant_info():
    """Autouse no-op that shadows any session-wide tenant-setup fixture so
    these unit tests stay isolated."""
    return None
|
||||
|
||||
|
||||
def _load_evaluation_app(monkeypatch):
    """Import api/apps/evaluation_app.py in isolation.

    Fabricates every dependency the module imports (quart, common.constants,
    api.apps, api.db.services.evaluation_service, api.utils.api_utils) as
    in-memory stub modules registered in sys.modules, then loads the app
    module from its file path with a _DummyManager so route decorators are
    no-ops. Returns the loaded module. All sys.modules patches are reverted
    by *monkeypatch* at test teardown.
    """
    # Repo root: this file lives four directories below it — TODO confirm
    # if the test tree is ever reorganized.
    repo_root = Path(__file__).resolve().parents[4]

    # Stub quart: the app only needs a `request` object with `.args`.
    quart_mod = ModuleType("quart")
    quart_mod.request = SimpleNamespace(args=_Args())
    monkeypatch.setitem(sys.modules, "quart", quart_mod)

    common_pkg = ModuleType("common")
    common_pkg.__path__ = [str(repo_root / "common")]
    monkeypatch.setitem(sys.modules, "common", common_pkg)

    # common.constants.RetCode is replaced with the local dummy codes.
    constants_mod = ModuleType("common.constants")
    constants_mod.RetCode = _DummyRetCode
    monkeypatch.setitem(sys.modules, "common.constants", constants_mod)
    common_pkg.constants = constants_mod

    api_pkg = ModuleType("api")
    api_pkg.__path__ = [str(repo_root / "api")]
    monkeypatch.setitem(sys.modules, "api", api_pkg)

    # api.apps supplies current_user and a pass-through login_required.
    apps_mod = ModuleType("api.apps")
    apps_mod.__path__ = [str(repo_root / "api" / "apps")]
    apps_mod.current_user = SimpleNamespace(id="tenant-1")
    apps_mod.login_required = lambda func: func
    monkeypatch.setitem(sys.modules, "api.apps", apps_mod)
    api_pkg.apps = apps_mod

    db_pkg = ModuleType("api.db")
    db_pkg.__path__ = []
    monkeypatch.setitem(sys.modules, "api.db", db_pkg)
    api_pkg.db = db_pkg

    services_pkg = ModuleType("api.db.services")
    services_pkg.__path__ = []
    monkeypatch.setitem(sys.modules, "api.db.services", services_pkg)

    evaluation_service_mod = ModuleType("api.db.services.evaluation_service")

    class _EvaluationService:
        # Default happy-path stubs; individual tests monkeypatch these
        # per-case to drive error branches.
        @staticmethod
        def create_dataset(**_kwargs):
            return True, "dataset-1"

        @staticmethod
        def list_datasets(**_kwargs):
            return {"datasets": [], "total": 0}

        @staticmethod
        def get_dataset(_dataset_id):
            return {"id": _dataset_id}

        @staticmethod
        def update_dataset(_dataset_id, **_kwargs):
            return True

        @staticmethod
        def delete_dataset(_dataset_id):
            return True

        @staticmethod
        def add_test_case(**_kwargs):
            return True, "case-1"

        @staticmethod
        def import_test_cases(**_kwargs):
            return 0, 0

        @staticmethod
        def get_test_cases(_dataset_id):
            return []

        @staticmethod
        def delete_test_case(_case_id):
            return True

        @staticmethod
        def start_evaluation(**_kwargs):
            return True, "run-1"

        @staticmethod
        def get_run_results(_run_id):
            return {"id": _run_id}

        @staticmethod
        def get_recommendations(_run_id):
            return []

    evaluation_service_mod.EvaluationService = _EvaluationService
    monkeypatch.setitem(sys.modules, "api.db.services.evaluation_service", evaluation_service_mod)

    utils_pkg = ModuleType("api.utils")
    utils_pkg.__path__ = []
    monkeypatch.setitem(sys.modules, "api.utils", utils_pkg)

    # api.utils.api_utils: dict-returning response helpers so tests can
    # assert on plain dicts instead of HTTP responses.
    api_utils_mod = ModuleType("api.utils.api_utils")

    async def _default_request_json():
        return {}

    def _get_data_error_result(code=_DummyRetCode.DATA_ERROR, message="Sorry! Data missing!"):
        return {"code": code, "message": message}

    def _get_json_result(code=_DummyRetCode.SUCCESS, message="success", data=None):
        return {"code": code, "message": message, "data": data}

    def _server_error_response(error):
        return {"code": _DummyRetCode.EXCEPTION_ERROR, "message": repr(error)}

    def _validate_request(*_args, **_kwargs):
        # Validation decorator becomes a no-op in unit tests.
        def _decorator(func):
            return func

        return _decorator

    api_utils_mod.get_data_error_result = _get_data_error_result
    api_utils_mod.get_json_result = _get_json_result
    api_utils_mod.get_request_json = _default_request_json
    api_utils_mod.server_error_response = _server_error_response
    api_utils_mod.validate_request = _validate_request
    monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod)
    utils_pkg.api_utils = api_utils_mod

    # Load evaluation_app.py from its file path under a private module name;
    # `manager` must exist before exec because route decorators use it.
    module_name = "test_evaluation_routes_unit_module"
    module_path = repo_root / "api" / "apps" / "evaluation_app.py"
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    module = importlib.util.module_from_spec(spec)
    module.manager = _DummyManager()
    monkeypatch.setitem(sys.modules, module_name, module)
    spec.loader.exec_module(module)
    return module
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_dataset_routes_matrix_unit(monkeypatch):
    """Branch matrix for the evaluation dataset CRUD routes: create (success,
    blank name, non-list kb_ids, service failure, exception), list (valid and
    non-numeric paging), get/update/delete (found, not-found/failure, and
    exception paths), plus update's protected-field filtering."""
    module = _load_evaluation_app(monkeypatch)

    # create: success — name is accepted, service id echoed back.
    _set_request_json(monkeypatch, module, {"name": " data-1 ", "description": "desc", "kb_ids": ["kb-1"]})
    monkeypatch.setattr(module.EvaluationService, "create_dataset", lambda **_kwargs: (True, "dataset-ok"))
    res = _run(module.create_dataset())
    assert res["code"] == 0
    assert res["data"]["dataset_id"] == "dataset-ok"

    # create: whitespace-only name rejected.
    _set_request_json(monkeypatch, module, {"name": " ", "kb_ids": ["kb-1"]})
    res = _run(module.create_dataset())
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "empty" in res["message"].lower()

    # create: kb_ids must be a list, not a scalar.
    _set_request_json(monkeypatch, module, {"name": "data-2", "kb_ids": "kb-1"})
    res = _run(module.create_dataset())
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "kb_ids" in res["message"]

    # create: service-reported failure propagates its message.
    _set_request_json(monkeypatch, module, {"name": "data-3", "kb_ids": ["kb-1"]})
    monkeypatch.setattr(module.EvaluationService, "create_dataset", lambda **_kwargs: (False, "create failed"))
    res = _run(module.create_dataset())
    assert res["code"] == module.RetCode.DATA_ERROR
    assert res["message"] == "create failed"

    # create: unexpected exception -> server error response.
    def _raise_create(**_kwargs):
        raise RuntimeError("create boom")

    monkeypatch.setattr(module.EvaluationService, "create_dataset", _raise_create)
    res = _run(module.create_dataset())
    assert res["code"] == module.RetCode.EXCEPTION_ERROR
    assert "create boom" in res["message"]

    # list: numeric paging args succeed.
    _set_request_args(monkeypatch, module, {"page": "2", "page_size": "3"})
    monkeypatch.setattr(module.EvaluationService, "list_datasets", lambda **_kwargs: {"datasets": [{"id": "a"}], "total": 1})
    res = _run(module.list_datasets())
    assert res["code"] == 0
    assert res["data"]["total"] == 1

    # list: non-numeric page surfaces as an exception-path error.
    _set_request_args(monkeypatch, module, {"page": "x"})
    res = _run(module.list_datasets())
    assert res["code"] == module.RetCode.EXCEPTION_ERROR

    # get: missing dataset.
    monkeypatch.setattr(module.EvaluationService, "get_dataset", lambda _dataset_id: None)
    res = _run(module.get_dataset("dataset-1"))
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "not found" in res["message"].lower()

    # get: success echoes the dataset.
    monkeypatch.setattr(module.EvaluationService, "get_dataset", lambda _dataset_id: {"id": _dataset_id})
    res = _run(module.get_dataset("dataset-2"))
    assert res["code"] == 0
    assert res["data"]["id"] == "dataset-2"

    # get: exception path.
    def _raise_get(_dataset_id):
        raise RuntimeError("get dataset boom")

    monkeypatch.setattr(module.EvaluationService, "get_dataset", _raise_get)
    res = _run(module.get_dataset("dataset-3"))
    assert res["code"] == module.RetCode.EXCEPTION_ERROR
    assert "get dataset boom" in res["message"]

    # update: protected fields (id/tenant_id/created_by/create_time) must be
    # stripped before reaching the service.
    captured = {}

    def _update(dataset_id, **kwargs):
        captured["dataset_id"] = dataset_id
        captured["kwargs"] = kwargs
        return True

    _set_request_json(
        monkeypatch,
        module,
        {
            "id": "forbidden",
            "tenant_id": "forbidden",
            "created_by": "forbidden",
            "create_time": 123,
            "name": "new-name",
        },
    )
    monkeypatch.setattr(module.EvaluationService, "update_dataset", _update)
    res = _run(module.update_dataset("dataset-4"))
    assert res["code"] == 0
    assert res["data"]["dataset_id"] == "dataset-4"
    assert captured["dataset_id"] == "dataset-4"
    assert "id" not in captured["kwargs"]
    assert "tenant_id" not in captured["kwargs"]
    assert "created_by" not in captured["kwargs"]
    assert "create_time" not in captured["kwargs"]

    # update: service failure.
    _set_request_json(monkeypatch, module, {"name": "new-name"})
    monkeypatch.setattr(module.EvaluationService, "update_dataset", lambda _dataset_id, **_kwargs: False)
    res = _run(module.update_dataset("dataset-5"))
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "failed" in res["message"].lower()

    # update: exception path.
    def _raise_update(_dataset_id, **_kwargs):
        raise RuntimeError("update boom")

    monkeypatch.setattr(module.EvaluationService, "update_dataset", _raise_update)
    res = _run(module.update_dataset("dataset-6"))
    assert res["code"] == module.RetCode.EXCEPTION_ERROR
    assert "update boom" in res["message"]

    # delete: service failure, then success, then exception path.
    monkeypatch.setattr(module.EvaluationService, "delete_dataset", lambda _dataset_id: False)
    res = _run(module.delete_dataset("dataset-7"))
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "failed" in res["message"].lower()

    monkeypatch.setattr(module.EvaluationService, "delete_dataset", lambda _dataset_id: True)
    res = _run(module.delete_dataset("dataset-8"))
    assert res["code"] == 0
    assert res["data"]["dataset_id"] == "dataset-8"

    def _raise_delete(_dataset_id):
        raise RuntimeError("delete dataset boom")

    monkeypatch.setattr(module.EvaluationService, "delete_dataset", _raise_delete)
    res = _run(module.delete_dataset("dataset-9"))
    assert res["code"] == module.RetCode.EXCEPTION_ERROR
    assert "delete dataset boom" in res["message"]
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_test_case_routes_matrix_unit(monkeypatch):
    """Exercise the test-case CRUD routes (add/import/list/delete) of the
    evaluation app, covering validation failures, service-level failures,
    success paths, and exception handling for each endpoint."""
    module = _load_evaluation_app(monkeypatch)

    # add_test_case: blank question is rejected before the service is called.
    _set_request_json(monkeypatch, module, {"question": " "})
    res = _run(module.add_test_case("dataset-1"))
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "question" in res["message"].lower()

    # add_test_case: service reports failure -> DATA_ERROR with its message.
    _set_request_json(monkeypatch, module, {"question": "q1"})
    monkeypatch.setattr(module.EvaluationService, "add_test_case", lambda **_kwargs: (False, "add failed"))
    res = _run(module.add_test_case("dataset-2"))
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "add failed" in res["message"]

    # add_test_case: full payload with optional fields succeeds.
    _set_request_json(
        monkeypatch,
        module,
        {
            "question": "q2",
            "reference_answer": "a2",
            "relevant_doc_ids": ["doc-1"],
            "relevant_chunk_ids": ["chunk-1"],
            "metadata": {"k": "v"},
        },
    )
    monkeypatch.setattr(module.EvaluationService, "add_test_case", lambda **_kwargs: (True, "case-ok"))
    res = _run(module.add_test_case("dataset-3"))
    assert res["code"] == 0
    assert res["data"]["case_id"] == "case-ok"

    # add_test_case: an unexpected exception maps to EXCEPTION_ERROR.
    def _raise_add(**_kwargs):
        raise RuntimeError("add case boom")

    monkeypatch.setattr(module.EvaluationService, "add_test_case", _raise_add)
    res = _run(module.add_test_case("dataset-4"))
    assert res["code"] == module.RetCode.EXCEPTION_ERROR
    assert "add case boom" in res["message"]

    # import_test_cases: "cases" must be a non-empty list, not a dict.
    _set_request_json(monkeypatch, module, {"cases": {}})
    res = _run(module.import_test_cases("dataset-5"))
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "cases" in res["message"]

    # import_test_cases: success/failure counts are echoed back with a total.
    _set_request_json(monkeypatch, module, {"cases": [{"question": "q1"}, {"question": "q2"}]})
    monkeypatch.setattr(module.EvaluationService, "import_test_cases", lambda **_kwargs: (2, 0))
    res = _run(module.import_test_cases("dataset-6"))
    assert res["code"] == 0
    assert res["data"]["success_count"] == 2
    assert res["data"]["failure_count"] == 0
    assert res["data"]["total"] == 2

    # import_test_cases: service exception maps to EXCEPTION_ERROR.
    def _raise_import(**_kwargs):
        raise RuntimeError("import boom")

    monkeypatch.setattr(module.EvaluationService, "import_test_cases", _raise_import)
    res = _run(module.import_test_cases("dataset-7"))
    assert res["code"] == module.RetCode.EXCEPTION_ERROR
    assert "import boom" in res["message"]

    # get_test_cases: happy path returns the service rows plus a total.
    monkeypatch.setattr(module.EvaluationService, "get_test_cases", lambda _dataset_id: [{"id": "case-1"}])
    res = _run(module.get_test_cases("dataset-8"))
    assert res["code"] == 0
    assert res["data"]["total"] == 1
    assert res["data"]["cases"][0]["id"] == "case-1"

    # get_test_cases: service exception maps to EXCEPTION_ERROR.
    def _raise_get_cases(_dataset_id):
        raise RuntimeError("get cases boom")

    monkeypatch.setattr(module.EvaluationService, "get_test_cases", _raise_get_cases)
    res = _run(module.get_test_cases("dataset-9"))
    assert res["code"] == module.RetCode.EXCEPTION_ERROR
    assert "get cases boom" in res["message"]

    # delete_test_case: service returning False -> DATA_ERROR.
    monkeypatch.setattr(module.EvaluationService, "delete_test_case", lambda _case_id: False)
    res = _run(module.delete_test_case("case-1"))
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "failed" in res["message"].lower()

    # delete_test_case: success echoes the deleted case id.
    monkeypatch.setattr(module.EvaluationService, "delete_test_case", lambda _case_id: True)
    res = _run(module.delete_test_case("case-2"))
    assert res["code"] == 0
    assert res["data"]["case_id"] == "case-2"

    # delete_test_case: service exception maps to EXCEPTION_ERROR.
    def _raise_delete_case(_case_id):
        raise RuntimeError("delete case boom")

    monkeypatch.setattr(module.EvaluationService, "delete_test_case", _raise_delete_case)
    res = _run(module.delete_test_case("case-3"))
    assert res["code"] == module.RetCode.EXCEPTION_ERROR
    assert "delete case boom" in res["message"]
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_run_and_recommendation_routes_matrix_unit(monkeypatch):
    """Exercise the evaluation-run routes (start/get/list/delete/results) and
    the recommendation route: service failure, success, and exception paths."""
    module = _load_evaluation_app(monkeypatch)

    # start_evaluation: service returning (False, msg) -> DATA_ERROR.
    _set_request_json(monkeypatch, module, {"dataset_id": "d1", "dialog_id": "dialog-1", "name": "run 1"})
    monkeypatch.setattr(module.EvaluationService, "start_evaluation", lambda **_kwargs: (False, "start failed"))
    res = _run(module.start_evaluation())
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "start failed" in res["message"]

    # start_evaluation: success returns the new run id.
    monkeypatch.setattr(module.EvaluationService, "start_evaluation", lambda **_kwargs: (True, "run-ok"))
    res = _run(module.start_evaluation())
    assert res["code"] == 0
    assert res["data"]["run_id"] == "run-ok"

    # start_evaluation: exceptions surface as EXCEPTION_ERROR.
    def _raise_start(**_kwargs):
        raise RuntimeError("start boom")

    monkeypatch.setattr(module.EvaluationService, "start_evaluation", _raise_start)
    res = _run(module.start_evaluation())
    assert res["code"] == module.RetCode.EXCEPTION_ERROR
    assert "start boom" in res["message"]

    # get_evaluation_run: missing run -> DATA_ERROR "not found".
    monkeypatch.setattr(module.EvaluationService, "get_run_results", lambda _run_id: None)
    res = _run(module.get_evaluation_run("run-1"))
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "not found" in res["message"].lower()

    # get_evaluation_run: found run is returned verbatim.
    monkeypatch.setattr(module.EvaluationService, "get_run_results", lambda _run_id: {"id": _run_id})
    res = _run(module.get_evaluation_run("run-2"))
    assert res["code"] == 0
    assert res["data"]["id"] == "run-2"

    # get_evaluation_run: exception path.
    def _raise_get_run(_run_id):
        raise RuntimeError("get run boom")

    monkeypatch.setattr(module.EvaluationService, "get_run_results", _raise_get_run)
    res = _run(module.get_evaluation_run("run-3"))
    assert res["code"] == module.RetCode.EXCEPTION_ERROR
    assert "get run boom" in res["message"]

    # get_run_results: same not-found / success / exception matrix.
    monkeypatch.setattr(module.EvaluationService, "get_run_results", lambda _run_id: None)
    res = _run(module.get_run_results("run-4"))
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "not found" in res["message"].lower()

    monkeypatch.setattr(module.EvaluationService, "get_run_results", lambda _run_id: {"id": _run_id, "score": 0.9})
    res = _run(module.get_run_results("run-5"))
    assert res["code"] == 0
    assert res["data"]["id"] == "run-5"

    def _raise_results(_run_id):
        raise RuntimeError("get results boom")

    monkeypatch.setattr(module.EvaluationService, "get_run_results", _raise_results)
    res = _run(module.get_run_results("run-6"))
    assert res["code"] == module.RetCode.EXCEPTION_ERROR
    assert "get results boom" in res["message"]

    # list_evaluation_runs: default path returns an empty listing.
    res = _run(module.list_evaluation_runs())
    assert res["code"] == 0
    assert res["data"]["total"] == 0

    # list_evaluation_runs: force get_json_result to raise -> EXCEPTION_ERROR.
    def _raise_json_list(*_args, **_kwargs):
        raise RuntimeError("list runs boom")

    monkeypatch.setattr(module, "get_json_result", _raise_json_list)
    res = _run(module.list_evaluation_runs())
    assert res["code"] == module.RetCode.EXCEPTION_ERROR
    assert "list runs boom" in res["message"]

    # delete_evaluation_run: restore get_json_result, success echoes run id.
    monkeypatch.setattr(module, "get_json_result", lambda code=0, message="success", data=None: {"code": code, "message": message, "data": data})
    res = _run(module.delete_evaluation_run("run-7"))
    assert res["code"] == 0
    assert res["data"]["run_id"] == "run-7"

    # delete_evaluation_run: exception path via a raising get_json_result.
    def _raise_json_delete(*_args, **_kwargs):
        raise RuntimeError("delete run boom")

    monkeypatch.setattr(module, "get_json_result", _raise_json_delete)
    res = _run(module.delete_evaluation_run("run-8"))
    assert res["code"] == module.RetCode.EXCEPTION_ERROR
    assert "delete run boom" in res["message"]

    # get_recommendations: success path wraps the service rows.
    monkeypatch.setattr(module, "get_json_result", lambda code=0, message="success", data=None: {"code": code, "message": message, "data": data})
    monkeypatch.setattr(module.EvaluationService, "get_recommendations", lambda _run_id: [{"name": "cfg-1"}])
    res = _run(module.get_recommendations("run-9"))
    assert res["code"] == 0
    assert res["data"]["recommendations"][0]["name"] == "cfg-1"

    # get_recommendations: exception path.
    def _raise_recommend(_run_id):
        raise RuntimeError("recommend boom")

    monkeypatch.setattr(module.EvaluationService, "get_recommendations", _raise_recommend)
    res = _run(module.get_recommendations("run-10"))
    assert res["code"] == module.RetCode.EXCEPTION_ERROR
    assert "recommend boom" in res["message"]
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_compare_export_and_evaluate_single_matrix_unit(monkeypatch):
    """Exercise compare_runs, export_results, and evaluate_single routes:
    validation errors, default success responses, and exception mapping."""
    module = _load_evaluation_app(monkeypatch)

    # compare_runs: fewer than two run ids is rejected.
    _set_request_json(monkeypatch, module, {"run_ids": ["run-1"]})
    res = _run(module.compare_runs())
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "at least 2" in res["message"]

    # compare_runs: two ids succeed with an (empty) comparison payload.
    _set_request_json(monkeypatch, module, {"run_ids": ["run-1", "run-2"]})
    res = _run(module.compare_runs())
    assert res["code"] == 0
    assert res["data"]["comparison"] == {}

    # compare_runs: force get_json_result to raise -> EXCEPTION_ERROR.
    def _raise_json_compare(*_args, **_kwargs):
        raise RuntimeError("compare boom")

    monkeypatch.setattr(module, "get_json_result", _raise_json_compare)
    _set_request_json(monkeypatch, module, {"run_ids": ["run-1", "run-2", "run-3"]})
    res = _run(module.compare_runs())
    assert res["code"] == module.RetCode.EXCEPTION_ERROR
    assert "compare boom" in res["message"]

    # export_results: restore get_json_result; missing run -> DATA_ERROR.
    monkeypatch.setattr(module, "get_json_result", lambda code=0, message="success", data=None: {"code": code, "message": message, "data": data})
    monkeypatch.setattr(module.EvaluationService, "get_run_results", lambda _run_id: None)
    res = _run(module.export_results("run-11"))
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "not found" in res["message"].lower()

    # export_results: found run is exported as-is.
    monkeypatch.setattr(module.EvaluationService, "get_run_results", lambda _run_id: {"id": _run_id, "rows": []})
    res = _run(module.export_results("run-12"))
    assert res["code"] == 0
    assert res["data"]["id"] == "run-12"

    # export_results: service exception maps to EXCEPTION_ERROR.
    def _raise_export(_run_id):
        raise RuntimeError("export boom")

    monkeypatch.setattr(module.EvaluationService, "get_run_results", _raise_export)
    res = _run(module.export_results("run-13"))
    assert res["code"] == module.RetCode.EXCEPTION_ERROR
    assert "export boom" in res["message"]

    # evaluate_single: default response carries empty answer/metrics/chunks.
    monkeypatch.setattr(module, "get_json_result", lambda code=0, message="success", data=None: {"code": code, "message": message, "data": data})
    res = _run(module.evaluate_single())
    assert res["code"] == 0
    assert res["data"]["answer"] == ""
    assert res["data"]["metrics"] == {}
    assert res["data"]["retrieved_chunks"] == []

    # evaluate_single: exception path via a raising get_json_result.
    def _raise_json_single(*_args, **_kwargs):
        raise RuntimeError("single boom")

    monkeypatch.setattr(module, "get_json_result", _raise_json_single)
    res = _run(module.evaluate_single())
    assert res["code"] == module.RetCode.EXCEPTION_ERROR
    assert "single boom" in res["message"]
|
||||
@ -0,0 +1,367 @@
|
||||
#
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import importlib.util
|
||||
import sys
|
||||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from types import ModuleType, SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class _DummyManager:
    """No-op stand-in for the app's route manager used while importing the
    route module in isolation: registering a route changes nothing."""

    def route(self, *_args, **_kwargs):
        # Hand the handler straight back so the decorator is transparent.
        return lambda handler: handler
|
||||
|
||||
|
||||
class _AwaitableValue:
    """Wrap a plain value so that ``await`` on it yields the value."""

    def __init__(self, value):
        self._value = value

    def __await__(self):
        # Implemented as a generator: yielding nothing and returning the
        # wrapped value is exactly the protocol ``await`` expects from an
        # already-resolved awaitable.
        wrapped = self._value
        if False:  # pragma: no cover - makes this method a generator
            yield
        return wrapped
|
||||
|
||||
|
||||
class _DummyFile:
    """Minimal file record exposing exactly the attributes the routes read."""

    def __init__(self, file_id, file_type, *, name="file.txt", location="loc", size=1):
        self.id, self.type = file_id, file_type
        self.name, self.location, self.size = name, location, size
|
||||
|
||||
|
||||
class _FalsyFile(_DummyFile):
    """A file record that is falsy in boolean context, used to drive the
    routes' "file not found" truthiness checks."""

    def __bool__(self):
        # Always falsy regardless of attribute values.
        return False
|
||||
|
||||
|
||||
def _run(coro):
    """Drive an async route handler to completion on a fresh event loop and
    hand back whatever it returns."""
    result = asyncio.run(coro)
    return result
|
||||
|
||||
|
||||
def _set_request_json(monkeypatch, module, payload_state):
    """Patch ``module.get_request_json`` so every call returns a deep copy of
    *payload_state* — handlers cannot mutate shared test state, while callers
    may still mutate *payload_state* between requests to vary the payload."""

    async def _fake_request_json():
        return deepcopy(payload_state)

    monkeypatch.setattr(module, "get_request_json", _fake_request_json)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def auth():
    """Session-wide dummy auth token; these unit tests never reach the API,
    so any non-empty string will do."""
    return "unit-auth"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
def set_tenant_info():
    """Override the shared tenant-info fixture with a no-op so that no live
    service is contacted during these isolated unit tests."""
    return None
|
||||
|
||||
|
||||
def _load_file2document_module(monkeypatch):
    """Load ``api/apps/file2document_app.py`` in isolation.

    Builds stub modules for every import the route module performs (api.apps,
    api.db, the service layer, api.utils.api_utils, common.*) and registers
    them in ``sys.modules`` via monkeypatch, then executes the real route file
    from disk with a dummy route manager. Returns the loaded module so tests
    can call its handlers and re-patch the stub services per scenario.
    """
    # parents[4] walks from this test file up to the repository root.
    repo_root = Path(__file__).resolve().parents[4]

    # --- top-level "api" package stub ---
    api_pkg = ModuleType("api")
    api_pkg.__path__ = [str(repo_root / "api")]
    monkeypatch.setitem(sys.modules, "api", api_pkg)

    # --- api.apps: provides current_user and a pass-through login_required ---
    apps_mod = ModuleType("api.apps")
    apps_mod.__path__ = [str(repo_root / "api" / "apps")]
    apps_mod.current_user = SimpleNamespace(id="user-1")
    apps_mod.login_required = lambda func: func
    monkeypatch.setitem(sys.modules, "api.apps", apps_mod)
    api_pkg.apps = apps_mod

    # --- api.db: only FileType is needed by the routes ---
    db_pkg = ModuleType("api.db")
    db_pkg.__path__ = []

    class _FileType(Enum):
        FOLDER = "folder"
        DOC = "doc"

    db_pkg.FileType = _FileType
    monkeypatch.setitem(sys.modules, "api.db", db_pkg)
    api_pkg.db = db_pkg

    services_pkg = ModuleType("api.db.services")
    services_pkg.__path__ = []
    monkeypatch.setitem(sys.modules, "api.db.services", services_pkg)

    # --- stub service modules; tests monkeypatch individual methods later ---
    file2document_mod = ModuleType("api.db.services.file2document_service")

    class _StubFile2DocumentService:
        @staticmethod
        def get_by_file_id(_file_id):
            return []

        @staticmethod
        def delete_by_file_id(*_args, **_kwargs):
            return None

        @staticmethod
        def insert(_payload):
            return SimpleNamespace(to_json=lambda: {})

    file2document_mod.File2DocumentService = _StubFile2DocumentService
    monkeypatch.setitem(sys.modules, "api.db.services.file2document_service", file2document_mod)
    services_pkg.file2document_service = file2document_mod

    file_service_mod = ModuleType("api.db.services.file_service")

    class _StubFileService:
        @staticmethod
        def get_by_ids(_file_ids):
            return []

        @staticmethod
        def get_all_innermost_file_ids(_file_id, _acc):
            return []

        @staticmethod
        def get_by_id(_file_id):
            return True, _DummyFile(_file_id, _FileType.DOC.value)

    file_service_mod.FileService = _StubFileService
    monkeypatch.setitem(sys.modules, "api.db.services.file_service", file_service_mod)
    services_pkg.file_service = file_service_mod

    kb_service_mod = ModuleType("api.db.services.knowledgebase_service")

    class _StubKnowledgebaseService:
        @staticmethod
        def get_by_id(_kb_id):
            return False, None

    kb_service_mod.KnowledgebaseService = _StubKnowledgebaseService
    monkeypatch.setitem(sys.modules, "api.db.services.knowledgebase_service", kb_service_mod)
    services_pkg.knowledgebase_service = kb_service_mod

    document_service_mod = ModuleType("api.db.services.document_service")

    class _StubDocumentService:
        @staticmethod
        def get_by_id(doc_id):
            return True, SimpleNamespace(id=doc_id)

        @staticmethod
        def get_tenant_id(_doc_id):
            return "tenant-1"

        @staticmethod
        def remove_document(*_args, **_kwargs):
            return True

        @staticmethod
        def insert(_payload):
            return SimpleNamespace(id="doc-1")

    document_service_mod.DocumentService = _StubDocumentService
    monkeypatch.setitem(sys.modules, "api.db.services.document_service", document_service_mod)
    services_pkg.document_service = document_service_mod

    # --- api.utils.api_utils: lightweight JSON-result helpers ---
    api_utils_mod = ModuleType("api.utils.api_utils")

    def get_json_result(data=None, message="", code=0):
        return {"code": code, "data": data, "message": message}

    def get_data_error_result(message=""):
        # 102 mirrors the project's DATA_ERROR code.
        return {"code": 102, "data": None, "message": message}

    async def get_request_json():
        return {}

    def server_error_response(err):
        return {"code": 500, "data": None, "message": str(err)}

    def validate_request(*_keys):
        # Pass-through decorator: validation itself is not under test here.
        def _decorator(func):
            @functools.wraps(func)
            async def _wrapper(*args, **kwargs):
                return await func(*args, **kwargs)

            return _wrapper

        return _decorator

    api_utils_mod.get_json_result = get_json_result
    api_utils_mod.get_data_error_result = get_data_error_result
    api_utils_mod.get_request_json = get_request_json
    api_utils_mod.server_error_response = server_error_response
    api_utils_mod.validate_request = validate_request
    monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod)

    misc_utils_mod = ModuleType("common.misc_utils")
    misc_utils_mod.get_uuid = lambda: "uuid"
    monkeypatch.setitem(sys.modules, "common.misc_utils", misc_utils_mod)

    constants_mod = ModuleType("common.constants")

    class _RetCode:
        ARGUMENT_ERROR = 101

    constants_mod.RetCode = _RetCode
    monkeypatch.setitem(sys.modules, "common.constants", constants_mod)

    # --- execute the real route file with the stubbed environment in place ---
    module_name = "test_file2document_routes_unit_module"
    module_path = repo_root / "api" / "apps" / "file2document_app.py"
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    module = importlib.util.module_from_spec(spec)
    # Inject the dummy manager before exec so @manager.route decorators no-op.
    module.manager = _DummyManager()
    monkeypatch.setitem(sys.modules, module_name, module)
    spec.loader.exec_module(module)
    return module
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_convert_branch_matrix_unit(monkeypatch):
    """Walk every error branch of the file->document convert route, then the
    folder-expansion success path, and finally the exception handler."""
    module = _load_file2document_module(monkeypatch)
    # req_state is mutated between requests; _set_request_json deep-copies it
    # per call, so each request sees the current contents.
    req_state = {"kb_ids": ["kb-1"], "file_ids": ["f1"]}
    _set_request_json(monkeypatch, module, req_state)

    events = {"deleted": []}

    # Falsy file object -> "File not found!".
    monkeypatch.setattr(module.FileService, "get_by_ids", lambda _ids: [_FalsyFile("f1", module.FileType.DOC.value)])
    res = _run(module.convert())
    assert res["message"] == "File not found!"

    # Existing link but the document lookup fails -> "Document not found!".
    monkeypatch.setattr(module.FileService, "get_by_ids", lambda _ids: [_DummyFile("f1", module.FileType.DOC.value)])
    monkeypatch.setattr(module.File2DocumentService, "get_by_file_id", lambda _file_id: [SimpleNamespace(document_id="doc-1")])
    monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (False, None))
    res = _run(module.convert())
    assert res["message"] == "Document not found!"

    # Tenant resolution fails -> "Tenant not found!".
    monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (True, SimpleNamespace(id=_doc_id)))
    monkeypatch.setattr(module.DocumentService, "get_tenant_id", lambda _doc_id: None)
    res = _run(module.convert())
    assert res["message"] == "Tenant not found!"

    # Document removal fails -> removal error message.
    monkeypatch.setattr(module.DocumentService, "get_tenant_id", lambda _doc_id: "tenant-1")
    monkeypatch.setattr(module.DocumentService, "remove_document", lambda *_args, **_kwargs: False)
    res = _run(module.convert())
    assert "Document removal" in res["message"]

    # Dataset lookup fails after the old link is deleted.
    monkeypatch.setattr(module.DocumentService, "remove_document", lambda *_args, **_kwargs: True)
    monkeypatch.setattr(module.File2DocumentService, "get_by_file_id", lambda _file_id: [])
    monkeypatch.setattr(module.File2DocumentService, "delete_by_file_id", lambda file_id: events["deleted"].append(file_id))
    monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (False, None))
    res = _run(module.convert())
    assert res["message"] == "Can't find this dataset!"
    assert events["deleted"] == ["f1"]

    # Dataset found but the inner file lookup fails.
    kb = SimpleNamespace(id="kb-1", parser_id="naive", pipeline_id="p1", parser_config={})
    monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, kb))
    monkeypatch.setattr(module.FileService, "get_by_id", lambda _file_id: (False, None))
    res = _run(module.convert())
    assert res["message"] == "Can't find this file!"

    # Success path: a folder expands to its innermost files, each of which is
    # inserted as a document and linked back to the file.
    req_state["file_ids"] = ["folder-1"]
    monkeypatch.setattr(module.FileService, "get_by_ids", lambda _ids: [_DummyFile("folder-1", module.FileType.FOLDER.value, name="folder")])
    monkeypatch.setattr(module.FileService, "get_all_innermost_file_ids", lambda _file_id, _acc: ["inner-1"])
    monkeypatch.setattr(
        module.FileService,
        "get_by_id",
        lambda _file_id: (True, _DummyFile("inner-1", module.FileType.DOC.value, name="inner.txt", location="inner.loc", size=2)),
    )
    monkeypatch.setattr(module.DocumentService, "insert", lambda _payload: SimpleNamespace(id="doc-new"))
    monkeypatch.setattr(
        module.File2DocumentService,
        "insert",
        lambda _payload: SimpleNamespace(to_json=lambda: {"file_id": "inner-1", "document_id": "doc-new"}),
    )
    res = _run(module.convert())
    assert res["code"] == 0
    assert res["data"] == [{"file_id": "inner-1", "document_id": "doc-new"}]

    # Unexpected exception inside the handler -> 500 with the error message.
    req_state["file_ids"] = ["f1"]
    monkeypatch.setattr(
        module.FileService,
        "get_by_ids",
        lambda _ids: (_ for _ in ()).throw(RuntimeError("convert boom")),
    )
    res = _run(module.convert())
    assert res["code"] == 500
    assert "convert boom" in res["message"]
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_rm_branch_matrix_unit(monkeypatch):
    """Walk every branch of the file-link removal route: argument validation,
    missing link/document/tenant, removal failure, multi-file success, and
    the exception handler."""
    module = _load_file2document_module(monkeypatch)
    # Mutated between requests; _set_request_json deep-copies per call.
    req_state = {"file_ids": []}
    _set_request_json(monkeypatch, module, req_state)

    deleted = []

    # Empty file_ids -> ARGUMENT_ERROR.
    res = _run(module.rm())
    assert res["code"] == module.RetCode.ARGUMENT_ERROR
    assert 'Lack of "Files ID"' in res["message"]

    # No link rows, or a falsy first row, both mean "Inform not found!".
    req_state["file_ids"] = ["f1"]
    monkeypatch.setattr(module.File2DocumentService, "get_by_file_id", lambda _file_id: [])
    res = _run(module.rm())
    assert res["message"] == "Inform not found!"

    monkeypatch.setattr(module.File2DocumentService, "get_by_file_id", lambda _file_id: [None])
    res = _run(module.rm())
    assert res["message"] == "Inform not found!"

    # Link exists but the document lookup fails; the link is still deleted.
    monkeypatch.setattr(module.File2DocumentService, "get_by_file_id", lambda _file_id: [SimpleNamespace(document_id="doc-1")])
    monkeypatch.setattr(module.File2DocumentService, "delete_by_file_id", lambda file_id: deleted.append(file_id))
    monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (False, None))
    res = _run(module.rm())
    assert res["message"] == "Document not found!"
    assert deleted == ["f1"]

    # Tenant resolution fails -> "Tenant not found!".
    monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (True, SimpleNamespace(id=_doc_id)))
    monkeypatch.setattr(module.DocumentService, "get_tenant_id", lambda _doc_id: None)
    res = _run(module.rm())
    assert res["message"] == "Tenant not found!"

    # Document removal fails -> removal error message.
    monkeypatch.setattr(module.DocumentService, "get_tenant_id", lambda _doc_id: "tenant-1")
    monkeypatch.setattr(module.DocumentService, "remove_document", lambda *_args, **_kwargs: False)
    res = _run(module.rm())
    assert "Document removal" in res["message"]

    # Full success across multiple files.
    req_state["file_ids"] = ["f1", "f2"]
    monkeypatch.setattr(
        module.File2DocumentService,
        "get_by_file_id",
        lambda file_id: [SimpleNamespace(document_id=f"doc-{file_id}")],
    )
    monkeypatch.setattr(module.DocumentService, "get_by_id", lambda doc_id: (True, SimpleNamespace(id=doc_id)))
    monkeypatch.setattr(module.DocumentService, "get_tenant_id", lambda _doc_id: "tenant-1")
    monkeypatch.setattr(module.DocumentService, "remove_document", lambda *_args, **_kwargs: True)
    res = _run(module.rm())
    assert res["code"] == 0
    assert res["data"] is True

    # Unexpected exception inside the handler -> 500 with the error message.
    monkeypatch.setattr(
        module.File2DocumentService,
        "get_by_file_id",
        lambda _file_id: (_ for _ in ()).throw(RuntimeError("rm boom")),
    )
    req_state["file_ids"] = ["boom"]
    res = _run(module.rm())
    assert res["code"] == 500
    assert "rm boom" in res["message"]
|
||||
1226
test/testcases/test_web_api/test_file_app/test_file_routes_unit.py
Normal file
1226
test/testcases/test_web_api/test_file_app/test_file_routes_unit.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -26,6 +26,8 @@ from types import ModuleType, SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.filterwarnings("ignore:.*joblib will operate in serial mode.*:UserWarning")
|
||||
|
||||
|
||||
class _DummyManager:
|
||||
def route(self, *_args, **_kwargs):
|
||||
@ -169,6 +171,16 @@ def _base_update_payload(**kwargs):
|
||||
return payload
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def auth():
|
||||
return "unit-auth"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def set_tenant_info():
|
||||
return None
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_create_branches(monkeypatch):
|
||||
module = _load_kb_module(monkeypatch)
|
||||
@ -1046,3 +1058,236 @@ def test_unbind_task_branch_matrix(monkeypatch):
|
||||
res = route()
|
||||
assert res["code"] == module.RetCode.EXCEPTION_ERROR, res
|
||||
assert "cannot delete task" in res["message"], res
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_check_embedding_similarity_threshold_matrix_unit(monkeypatch):
    """Drive the check_embedding route through its sampling/skip branches
    (missing vector, bad vector type, zero vector, no text, title+content mix)
    and through both the below-threshold (NOT_EFFECTIVE) and above-threshold
    (SUCCESS) outcomes."""
    module = _load_kb_module(monkeypatch)
    # Unwrap the route decorators so the handler can be called directly.
    route = inspect.unwrap(module.check_embedding)
    monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, SimpleNamespace(tenant_id="tenant-1")))
    monkeypatch.setattr(module.search, "index_name", lambda _tenant_id: "idx")

    class _FlipBool:
        # Truthy exactly once, then falsy: used to make the stripped-text
        # check pass the first time and fail the second, hitting "no_text".
        def __init__(self):
            self._calls = 0

        def __bool__(self):
            self._calls += 1
            return self._calls == 1

    # Only chunks containing the trigger marker get the flip-bool treatment.
    monkeypatch.setattr(
        module.re,
        "sub",
        lambda _pattern, _repl, text: _FlipBool() if "TRIGGER_NO_TEXT" in str(text) else text,
    )

    def _fixed_sample(population, k):
        # Deterministic "random" sampling so assertions are stable.
        return list(population)[:k]

    monkeypatch.setattr(module.random, "sample", _fixed_sample)

    class _DocStore:
        """In-memory stand-in for settings.docStoreConn with fixed results."""

        def __init__(self, total, ids_by_offset, docs):
            self.total = total
            self.ids_by_offset = ids_by_offset
            self.docs = docs

        def search(self, select_fields, **kwargs):
            # Empty select_fields marks the count query; otherwise a sample
            # page identified by its offset.
            if not select_fields:
                return {"kind": "total"}
            return {"kind": "sample", "offset": kwargs["offset"]}

        def get_total(self, _res):
            return self.total

        def get_doc_ids(self, res):
            return self.ids_by_offset.get(res.get("offset", -1), [])

        def get(self, cid, _index_name, _kb_ids):
            return self.docs.get(cid, {})

    class _EmbModel:
        """Embedding stub returning fixed vectors per document title."""

        def __init__(self):
            self.calls = []

        def encode(self, pair):
            title, _txt = pair
            self.calls.append(title)
            if title == "Doc Mix":
                # title+content mix wins over content only path.
                return [module.np.array([1.0, 0.0]), module.np.array([0.0, 1.0])], None
            if title == "Doc High":
                return [module.np.array([1.0, 0.0]), module.np.array([1.0, 0.0])], None
            return [module.np.array([0.0, 1.0]), module.np.array([0.0, 1.0])], None

    emb_model = _EmbModel()
    monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: emb_model)

    # Chunks crafted to hit each skip/low-similarity branch in turn.
    low_docs = {
        "chunk-no-vec": {
            "doc_id": "doc-no-vec",
            "docnm_kwd": "Doc No Vec",
            "content_with_weight": "body-no-vec",
            "page_num_int": 1,
            "position_int": 1,
            "top_int": 1,
        },
        "chunk-bad-type": {
            "doc_id": "doc-bad-type",
            "docnm_kwd": "Doc Bad Type",
            "content_with_weight": "body-bad-type",
            "question_kwd": [],
            "q_vec": {"bad": "type"},
            "page_num_int": 1,
            "position_int": 2,
            "top_int": 2,
        },
        "chunk-low-zero": {
            "doc_id": "doc-low-zero",
            "docnm_kwd": "Doc Low Zero",
            "content_with_weight": "body-low",
            "question_kwd": [],
            "q_vec": "0\t0",
            "page_num_int": 1,
            "position_int": 3,
            "top_int": 3,
        },
        "chunk-no-text": {
            "doc_id": "doc-no-text",
            "docnm_kwd": "Doc No Text",
            "content_with_weight": "TRIGGER_NO_TEXT",
            "q_vec": [1.0, 0.0],
            "page_num_int": 1,
            "position_int": 4,
            "top_int": 4,
        },
        "chunk-mix": {
            "doc_id": "doc-mix",
            "docnm_kwd": "Doc Mix",
            "content_with_weight": "body-mix",
            "q_vec": [1.0, 0.0],
            "page_num_int": 1,
            "position_int": 5,
            "top_int": 5,
        },
    }

    monkeypatch.setattr(
        module.settings,
        "docStoreConn",
        _DocStore(
            total=6,
            ids_by_offset={
                0: [],
                1: ["chunk-no-vec"],
                2: ["chunk-bad-type"],
                3: ["chunk-low-zero"],
                4: ["chunk-no-text"],
                5: ["chunk-mix"],
            },
            docs=low_docs,
        ),
    )

    # Low-similarity corpus -> NOT_EFFECTIVE with per-chunk skip reasons.
    _set_request_json(monkeypatch, module, {"kb_id": "kb-1", "embd_id": "emb-1", "check_num": 6})
    res = _run(route())
    assert res["code"] == module.RetCode.NOT_EFFECTIVE, res
    assert "average similarity" in res["message"], res
    summary = res["data"]["summary"]
    assert summary["sampled"] == 5, summary
    assert summary["valid"] == 2, summary
    reasons = {item.get("reason") for item in res["data"]["results"] if "reason" in item}
    assert "no_stored_vector" in reasons, res
    assert "no_text" in reasons, res
    assert any(item.get("chunk_id") == "chunk-low-zero" and "cos_sim" in item for item in res["data"]["results"]), res
    assert summary["match_mode"] in {"content_only", "title+content"}, summary

    # Single well-aligned chunk -> SUCCESS with a high average similarity.
    high_docs = {
        "chunk-high": {
            "doc_id": "doc-high",
            "docnm_kwd": "Doc High",
            "content_with_weight": "body-high",
            "q_vec": [1.0, 0.0],
            "page_num_int": 1,
            "position_int": 1,
            "top_int": 1,
        }
    }
    monkeypatch.setattr(
        module.settings,
        "docStoreConn",
        _DocStore(total=1, ids_by_offset={0: ["chunk-high"]}, docs=high_docs),
    )
    _set_request_json(monkeypatch, module, {"kb_id": "kb-1", "embd_id": "emb-1", "check_num": 1})
    res = _run(route())
    assert res["code"] == module.RetCode.SUCCESS, res
    assert res["data"]["summary"]["avg_cos_sim"] > 0.9, res
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_check_embedding_error_and_empty_sample_paths_unit(monkeypatch):
|
||||
module = _load_kb_module(monkeypatch)
|
||||
route = inspect.unwrap(module.check_embedding)
|
||||
monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, SimpleNamespace(tenant_id="tenant-1")))
|
||||
monkeypatch.setattr(module.search, "index_name", lambda _tenant_id: "idx")
|
||||
monkeypatch.setattr(module.random, "sample", lambda population, k: list(population)[:k])
|
||||
|
||||
class _DocStore:
|
||||
def __init__(self, total, ids_by_offset, docs):
|
||||
self.total = total
|
||||
self.ids_by_offset = ids_by_offset
|
||||
self.docs = docs
|
||||
|
||||
def search(self, select_fields, **kwargs):
|
||||
if not select_fields:
|
||||
return {"kind": "total"}
|
||||
return {"kind": "sample", "offset": kwargs["offset"]}
|
||||
|
||||
def get_total(self, _res):
|
||||
return self.total
|
||||
|
||||
def get_doc_ids(self, res):
|
||||
return self.ids_by_offset.get(res.get("offset", -1), [])
|
||||
|
||||
def get(self, cid, _index_name, _kb_ids):
|
||||
return self.docs.get(cid, {})
|
||||
|
||||
class _BoomEmbModel:
|
||||
def encode(self, _pair):
|
||||
raise RuntimeError("encode boom")
|
||||
|
||||
monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: _BoomEmbModel())
|
||||
monkeypatch.setattr(
|
||||
module.settings,
|
||||
"docStoreConn",
|
||||
_DocStore(
|
||||
total=1,
|
||||
ids_by_offset={0: ["chunk-err"]},
|
||||
docs={
|
||||
"chunk-err": {
|
||||
"doc_id": "doc-err",
|
||||
"docnm_kwd": "Doc Err",
|
||||
"content_with_weight": "body-err",
|
||||
"q_vec": [1.0, 0.0],
|
||||
"page_num_int": 1,
|
||||
"position_int": 1,
|
||||
"top_int": 1,
|
||||
}
|
||||
},
|
||||
),
|
||||
)
|
||||
_set_request_json(monkeypatch, module, {"kb_id": "kb-1", "embd_id": "emb-1", "check_num": 1})
|
||||
res = _run(route())
|
||||
assert res["code"] == module.RetCode.DATA_ERROR, res
|
||||
assert "Embedding failure." in res["message"], res
|
||||
assert "encode boom" in res["message"], res
|
||||
|
||||
class _OkEmbModel:
|
||||
def encode(self, _pair):
|
||||
return [module.np.array([1.0, 0.0]), module.np.array([1.0, 0.0])], None
|
||||
|
||||
monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: _OkEmbModel())
|
||||
monkeypatch.setattr(module.settings, "docStoreConn", _DocStore(total=0, ids_by_offset={}, docs={}))
|
||||
_set_request_json(monkeypatch, module, {"kb_id": "kb-1", "embd_id": "emb-1", "check_num": 1})
|
||||
with pytest.raises(UnboundLocalError):
|
||||
_run(route())
|
||||
|
||||
@ -16,6 +16,7 @@
|
||||
|
||||
import asyncio
|
||||
import importlib.util
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from types import ModuleType, SimpleNamespace
|
||||
@ -39,18 +40,39 @@ class _ExprField:
|
||||
return (self.name, other)
|
||||
|
||||
|
||||
class _StrEnum(str):
|
||||
@property
|
||||
def value(self):
|
||||
return str(self)
|
||||
|
||||
|
||||
class _DummyTenantLLMModel:
|
||||
tenant_id = _ExprField("tenant_id")
|
||||
llm_factory = _ExprField("llm_factory")
|
||||
llm_name = _ExprField("llm_name")
|
||||
|
||||
|
||||
class _TenantLLMRow:
|
||||
def __init__(self, *, llm_name, llm_factory, model_type, api_key="key", status="1"):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
llm_name,
|
||||
llm_factory,
|
||||
model_type,
|
||||
api_key="key",
|
||||
status="1",
|
||||
used_tokens=0,
|
||||
api_base="",
|
||||
max_tokens=8192,
|
||||
):
|
||||
self.llm_name = llm_name
|
||||
self.llm_factory = llm_factory
|
||||
self.model_type = model_type
|
||||
self.api_key = api_key
|
||||
self.status = status
|
||||
self.used_tokens = used_tokens
|
||||
self.api_base = api_base
|
||||
self.max_tokens = max_tokens
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
@ -58,15 +80,19 @@ class _TenantLLMRow:
|
||||
"llm_factory": self.llm_factory,
|
||||
"model_type": self.model_type,
|
||||
"status": self.status,
|
||||
"used_tokens": self.used_tokens,
|
||||
"api_base": self.api_base,
|
||||
"max_tokens": self.max_tokens,
|
||||
}
|
||||
|
||||
|
||||
class _LLMRow:
|
||||
def __init__(self, *, llm_name, fid, model_type, status="1"):
|
||||
def __init__(self, *, llm_name, fid, model_type, status="1", max_tokens=2048):
|
||||
self.llm_name = llm_name
|
||||
self.fid = fid
|
||||
self.model_type = model_type
|
||||
self.status = status
|
||||
self.max_tokens = max_tokens
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
@ -74,6 +100,7 @@ class _LLMRow:
|
||||
"fid": self.fid,
|
||||
"model_type": self.model_type,
|
||||
"status": self.status,
|
||||
"max_tokens": self.max_tokens,
|
||||
}
|
||||
|
||||
|
||||
@ -81,6 +108,13 @@ def _run(coro):
|
||||
return asyncio.run(coro)
|
||||
|
||||
|
||||
def _set_request_json(monkeypatch, module, payload):
|
||||
async def _get_request_json():
|
||||
return dict(payload)
|
||||
|
||||
monkeypatch.setattr(module, "get_request_json", _get_request_json)
|
||||
|
||||
|
||||
def _load_llm_app(monkeypatch):
|
||||
repo_root = Path(__file__).resolve().parents[4]
|
||||
|
||||
@ -122,6 +156,10 @@ def _load_llm_app(monkeypatch):
|
||||
def filter_delete(_filters):
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def filter_update(_filters, _payload):
|
||||
return True
|
||||
|
||||
tenant_llm_mod.LLMFactoriesService = _StubLLMFactoriesService
|
||||
tenant_llm_mod.TenantLLMService = _StubTenantLLMService
|
||||
monkeypatch.setitem(sys.modules, "api.db.services.tenant_llm_service", tenant_llm_mod)
|
||||
@ -164,13 +202,13 @@ def _load_llm_app(monkeypatch):
|
||||
constants_mod = ModuleType("common.constants")
|
||||
constants_mod.StatusEnum = SimpleNamespace(VALID=SimpleNamespace(value="1"), INVALID=SimpleNamespace(value="0"))
|
||||
constants_mod.LLMType = SimpleNamespace(
|
||||
CHAT="chat",
|
||||
EMBEDDING="embedding",
|
||||
SPEECH2TEXT="speech2text",
|
||||
IMAGE2TEXT="image2text",
|
||||
RERANK="rerank",
|
||||
TTS="tts",
|
||||
OCR="ocr",
|
||||
CHAT=_StrEnum("chat"),
|
||||
EMBEDDING=_StrEnum("embedding"),
|
||||
SPEECH2TEXT=_StrEnum("speech2text"),
|
||||
IMAGE2TEXT=_StrEnum("image2text"),
|
||||
RERANK=_StrEnum("rerank"),
|
||||
TTS=_StrEnum("tts"),
|
||||
OCR=_StrEnum("ocr"),
|
||||
)
|
||||
monkeypatch.setitem(sys.modules, "common.constants", constants_mod)
|
||||
|
||||
@ -179,7 +217,7 @@ def _load_llm_app(monkeypatch):
|
||||
monkeypatch.setitem(sys.modules, "api.db.db_models", db_models_mod)
|
||||
|
||||
base64_mod = ModuleType("rag.utils.base64_image")
|
||||
base64_mod.test_image = lambda _s: _s
|
||||
base64_mod.test_image = b"image-bytes"
|
||||
monkeypatch.setitem(sys.modules, "rag.utils.base64_image", base64_mod)
|
||||
|
||||
rag_llm_mod = ModuleType("rag.llm")
|
||||
@ -288,3 +326,529 @@ def test_list_app_exception_path(monkeypatch):
|
||||
res = _run(module.list_app())
|
||||
assert res["code"] == 500
|
||||
assert "query boom" in res["message"]
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_factories_route_success_and_exception_unit(monkeypatch):
|
||||
module = _load_llm_app(monkeypatch)
|
||||
|
||||
def _factory(name):
|
||||
return SimpleNamespace(name=name, to_dict=lambda n=name: {"name": n})
|
||||
|
||||
monkeypatch.setattr(
|
||||
module,
|
||||
"get_allowed_llm_factories",
|
||||
lambda: [
|
||||
_factory("OpenAI"),
|
||||
_factory("CustomFactory"),
|
||||
_factory("FastEmbed"),
|
||||
_factory("Builtin"),
|
||||
],
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
module.LLMService,
|
||||
"get_all",
|
||||
lambda: [
|
||||
_LLMRow(llm_name="m1", fid="OpenAI", model_type="chat", status="1"),
|
||||
_LLMRow(llm_name="m2", fid="OpenAI", model_type="embedding", status="1"),
|
||||
_LLMRow(llm_name="m3", fid="OpenAI", model_type="rerank", status="0"),
|
||||
],
|
||||
)
|
||||
res = module.factories()
|
||||
assert res["code"] == 0
|
||||
names = [item["name"] for item in res["data"]]
|
||||
assert "FastEmbed" not in names
|
||||
assert "Builtin" not in names
|
||||
assert {"OpenAI", "CustomFactory"} == set(names)
|
||||
openai = next(item for item in res["data"] if item["name"] == "OpenAI")
|
||||
assert {"chat", "embedding"} == set(openai["model_types"])
|
||||
|
||||
monkeypatch.setattr(module, "get_allowed_llm_factories", lambda: (_ for _ in ()).throw(RuntimeError("factories boom")))
|
||||
res = module.factories()
|
||||
assert res["code"] == 500
|
||||
assert "factories boom" in res["message"]
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_set_api_key_model_probe_matrix_unit(monkeypatch):
|
||||
module = _load_llm_app(monkeypatch)
|
||||
|
||||
async def _wait_for(coro, *_args, **_kwargs):
|
||||
return await coro
|
||||
|
||||
async def _to_thread(fn, *args, **kwargs):
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
monkeypatch.setattr(module.asyncio, "wait_for", _wait_for)
|
||||
monkeypatch.setattr(module.asyncio, "to_thread", _to_thread)
|
||||
|
||||
class _EmbeddingFail:
|
||||
def __init__(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
def encode(self, _texts):
|
||||
return [[]], 1
|
||||
|
||||
class _EmbeddingPass:
|
||||
def __init__(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
def encode(self, _texts):
|
||||
return [[0.1]], 1
|
||||
|
||||
class _ChatFail:
|
||||
def __init__(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
async def async_chat(self, *_args, **_kwargs):
|
||||
return "**ERROR** chat fail", 1
|
||||
|
||||
class _RerankFail:
|
||||
def __init__(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
def similarity(self, *_args, **_kwargs):
|
||||
return [], 0
|
||||
|
||||
factory = "FactoryA"
|
||||
monkeypatch.setattr(
|
||||
module.LLMService,
|
||||
"query",
|
||||
lambda **_kwargs: [
|
||||
_LLMRow(llm_name="emb", fid=factory, model_type=module.LLMType.EMBEDDING.value, max_tokens=321),
|
||||
_LLMRow(llm_name="chat", fid=factory, model_type=module.LLMType.CHAT.value, max_tokens=654),
|
||||
_LLMRow(llm_name="rerank", fid=factory, model_type=module.LLMType.RERANK.value, max_tokens=987),
|
||||
],
|
||||
)
|
||||
monkeypatch.setattr(module, "EmbeddingModel", {factory: _EmbeddingFail})
|
||||
monkeypatch.setattr(module, "ChatModel", {factory: _ChatFail})
|
||||
monkeypatch.setattr(module, "RerankModel", {factory: _RerankFail})
|
||||
|
||||
req = {"llm_factory": factory, "api_key": "k", "base_url": "http://x", "verify": True}
|
||||
_set_request_json(monkeypatch, module, req)
|
||||
res = _run(module.set_api_key())
|
||||
assert res["code"] == 0
|
||||
assert res["data"]["success"] is False
|
||||
assert "Fail to access embedding model(emb)" in res["data"]["message"]
|
||||
assert "Fail to access model(FactoryA/chat)" in res["data"]["message"]
|
||||
assert "Fail to access model(FactoryA/rerank)" in res["data"]["message"]
|
||||
|
||||
req["verify"] = False
|
||||
_set_request_json(monkeypatch, module, req)
|
||||
res = _run(module.set_api_key())
|
||||
assert res["code"] == 400
|
||||
assert "Fail to access embedding model(emb)" in res["message"]
|
||||
|
||||
calls = {"filter_update": [], "save": []}
|
||||
|
||||
def _filter_update(filters, payload):
|
||||
calls["filter_update"].append((filters, dict(payload)))
|
||||
return False
|
||||
|
||||
def _save(**kwargs):
|
||||
calls["save"].append(kwargs)
|
||||
return True
|
||||
|
||||
monkeypatch.setattr(module, "EmbeddingModel", {factory: _EmbeddingPass})
|
||||
monkeypatch.setattr(module.LLMService, "query", lambda **_kwargs: [_LLMRow(llm_name="emb-pass", fid=factory, model_type=module.LLMType.EMBEDDING.value, max_tokens=2049)])
|
||||
monkeypatch.setattr(module.TenantLLMService, "filter_update", _filter_update)
|
||||
monkeypatch.setattr(module.TenantLLMService, "save", _save)
|
||||
|
||||
success_req = {
|
||||
"llm_factory": factory,
|
||||
"api_key": "k2",
|
||||
"base_url": "http://y",
|
||||
"model_type": "chat",
|
||||
"llm_name": "manual-model",
|
||||
}
|
||||
_set_request_json(monkeypatch, module, success_req)
|
||||
res = _run(module.set_api_key())
|
||||
assert res["code"] == 0
|
||||
assert res["data"] is True
|
||||
assert calls["filter_update"]
|
||||
assert calls["filter_update"][0][1]["model_type"] == "chat"
|
||||
assert calls["filter_update"][0][1]["llm_name"] == "manual-model"
|
||||
assert calls["filter_update"][0][1]["max_tokens"] == 2049
|
||||
assert calls["save"][0]["max_tokens"] == 2049
|
||||
assert calls["save"][0]["llm_name"] == "emb-pass"
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_add_llm_factory_specific_key_assembly_unit(monkeypatch):
|
||||
module = _load_llm_app(monkeypatch)
|
||||
|
||||
async def _wait_for(coro, *_args, **_kwargs):
|
||||
return await coro
|
||||
|
||||
async def _to_thread(fn, *args, **kwargs):
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
monkeypatch.setattr(module.asyncio, "wait_for", _wait_for)
|
||||
monkeypatch.setattr(module.asyncio, "to_thread", _to_thread)
|
||||
|
||||
allowed = [
|
||||
"VolcEngine",
|
||||
"Tencent Cloud",
|
||||
"Bedrock",
|
||||
"LocalAI",
|
||||
"HuggingFace",
|
||||
"OpenAI-API-Compatible",
|
||||
"VLLM",
|
||||
"XunFei Spark",
|
||||
"BaiduYiyan",
|
||||
"Fish Audio",
|
||||
"Google Cloud",
|
||||
"Azure-OpenAI",
|
||||
"OpenRouter",
|
||||
"MinerU",
|
||||
"PaddleOCR",
|
||||
]
|
||||
monkeypatch.setattr(module, "get_allowed_llm_factories", lambda: [SimpleNamespace(name=name) for name in allowed])
|
||||
|
||||
captured = {"chat": [], "tts": [], "filter_payloads": []}
|
||||
|
||||
class _ChatOK:
|
||||
def __init__(self, key, model_name, base_url="", **_kwargs):
|
||||
captured["chat"].append((key, model_name, base_url))
|
||||
|
||||
async def async_chat(self, *_args, **_kwargs):
|
||||
return "ok", 1
|
||||
|
||||
class _TTSOK:
|
||||
def __init__(self, key, model_name, base_url="", **_kwargs):
|
||||
captured["tts"].append((key, model_name, base_url))
|
||||
|
||||
def tts(self, _text):
|
||||
yield b"ok"
|
||||
|
||||
monkeypatch.setattr(module, "ChatModel", {name: _ChatOK for name in allowed})
|
||||
monkeypatch.setattr(module, "TTSModel", {"XunFei Spark": _TTSOK})
|
||||
monkeypatch.setattr(module.TenantLLMService, "filter_update", lambda _filters, payload: captured["filter_payloads"].append(dict(payload)) or True)
|
||||
|
||||
reject_req = {"llm_factory": "NotAllowed", "llm_name": "x", "model_type": module.LLMType.CHAT.value}
|
||||
_set_request_json(monkeypatch, module, reject_req)
|
||||
res = _run(module.add_llm())
|
||||
assert res["code"] == 400
|
||||
assert "is not allowed" in res["message"]
|
||||
|
||||
def _run_case(factory, *, model_type=module.LLMType.CHAT.value, extra=None):
|
||||
req = {"llm_factory": factory, "llm_name": "model", "model_type": model_type, "api_key": "k", "api_base": "http://api"}
|
||||
if extra:
|
||||
req.update(extra)
|
||||
_set_request_json(monkeypatch, module, req)
|
||||
out = _run(module.add_llm())
|
||||
assert out["code"] == 0
|
||||
assert out["data"] is True
|
||||
return captured["filter_payloads"][-1]
|
||||
|
||||
volc = _run_case("VolcEngine", extra={"ark_api_key": "ak", "endpoint_id": "eid"})
|
||||
assert json.loads(volc["api_key"]) == {"ark_api_key": "ak", "endpoint_id": "eid"}
|
||||
|
||||
bedrock = _run_case(
|
||||
"Bedrock",
|
||||
extra={"auth_mode": "iam", "bedrock_ak": "ak", "bedrock_sk": "sk", "bedrock_region": "r", "aws_role_arn": "arn"},
|
||||
)
|
||||
assert json.loads(bedrock["api_key"]) == {
|
||||
"auth_mode": "iam",
|
||||
"bedrock_ak": "ak",
|
||||
"bedrock_sk": "sk",
|
||||
"bedrock_region": "r",
|
||||
"aws_role_arn": "arn",
|
||||
}
|
||||
|
||||
localai = _run_case("LocalAI")
|
||||
assert localai["llm_name"] == "model___LocalAI"
|
||||
huggingface = _run_case("HuggingFace")
|
||||
assert huggingface["llm_name"] == "model___HuggingFace"
|
||||
openapi = _run_case("OpenAI-API-Compatible")
|
||||
assert openapi["llm_name"] == "model___OpenAI-API"
|
||||
vllm = _run_case("VLLM")
|
||||
assert vllm["llm_name"] == "model___VLLM"
|
||||
|
||||
spark_chat = _run_case("XunFei Spark", extra={"spark_api_password": "spark-pass"})
|
||||
assert spark_chat["api_key"] == "spark-pass"
|
||||
spark_tts = _run_case(
|
||||
"XunFei Spark",
|
||||
model_type=module.LLMType.TTS.value,
|
||||
extra={"spark_app_id": "app", "spark_api_secret": "secret", "spark_api_key": "key"},
|
||||
)
|
||||
assert json.loads(spark_tts["api_key"]) == {
|
||||
"spark_app_id": "app",
|
||||
"spark_api_secret": "secret",
|
||||
"spark_api_key": "key",
|
||||
}
|
||||
|
||||
baidu = _run_case("BaiduYiyan", extra={"yiyan_ak": "ak", "yiyan_sk": "sk"})
|
||||
assert json.loads(baidu["api_key"]) == {"yiyan_ak": "ak", "yiyan_sk": "sk"}
|
||||
fish = _run_case("Fish Audio", extra={"fish_audio_ak": "ak", "fish_audio_refid": "rid"})
|
||||
assert json.loads(fish["api_key"]) == {"fish_audio_ak": "ak", "fish_audio_refid": "rid"}
|
||||
google = _run_case(
|
||||
"Google Cloud",
|
||||
extra={"google_project_id": "pid", "google_region": "us", "google_service_account_key": "sak"},
|
||||
)
|
||||
assert json.loads(google["api_key"]) == {
|
||||
"google_project_id": "pid",
|
||||
"google_region": "us",
|
||||
"google_service_account_key": "sak",
|
||||
}
|
||||
azure = _run_case("Azure-OpenAI", extra={"api_key": "real-key", "api_version": "2024-01-01"})
|
||||
assert json.loads(azure["api_key"]) == {"api_key": "real-key", "api_version": "2024-01-01"}
|
||||
openrouter = _run_case("OpenRouter", extra={"api_key": "or-key", "provider_order": "a,b"})
|
||||
assert json.loads(openrouter["api_key"]) == {"api_key": "or-key", "provider_order": "a,b"}
|
||||
mineru = _run_case("MinerU", extra={"api_key": "m-key", "provider_order": "p1"})
|
||||
assert json.loads(mineru["api_key"]) == {"api_key": "m-key", "provider_order": "p1"}
|
||||
paddle = _run_case("PaddleOCR", extra={"api_key": "p-key", "provider_order": "p2"})
|
||||
assert json.loads(paddle["api_key"]) == {"api_key": "p-key", "provider_order": "p2"}
|
||||
|
||||
tencent_req = {
|
||||
"llm_factory": "Tencent Cloud",
|
||||
"llm_name": "model",
|
||||
"model_type": module.LLMType.CHAT.value,
|
||||
"tencent_cloud_sid": "sid",
|
||||
"tencent_cloud_sk": "sk",
|
||||
}
|
||||
|
||||
async def _tencent_request_json():
|
||||
return tencent_req
|
||||
|
||||
monkeypatch.setattr(module, "get_request_json", _tencent_request_json)
|
||||
delegated = {}
|
||||
|
||||
async def _fake_set_api_key():
|
||||
delegated["api_key"] = tencent_req.get("api_key")
|
||||
return {"code": 0, "data": "delegated"}
|
||||
|
||||
monkeypatch.setattr(module, "set_api_key", _fake_set_api_key)
|
||||
res = _run(module.add_llm())
|
||||
assert res["code"] == 0
|
||||
assert res["data"] == "delegated"
|
||||
assert json.loads(delegated["api_key"]) == {"tencent_cloud_sid": "sid", "tencent_cloud_sk": "sk"}
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_add_llm_model_type_probe_and_persistence_matrix_unit(monkeypatch):
|
||||
module = _load_llm_app(monkeypatch)
|
||||
|
||||
async def _wait_for(coro, *_args, **_kwargs):
|
||||
return await coro
|
||||
|
||||
async def _to_thread(fn, *args, **kwargs):
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
monkeypatch.setattr(module.asyncio, "wait_for", _wait_for)
|
||||
monkeypatch.setattr(module.asyncio, "to_thread", _to_thread)
|
||||
monkeypatch.setattr(
|
||||
module,
|
||||
"get_allowed_llm_factories",
|
||||
lambda: [
|
||||
SimpleNamespace(name=name)
|
||||
for name in [
|
||||
"FEmbFail",
|
||||
"FEmbPass",
|
||||
"FChatFail",
|
||||
"FChatPass",
|
||||
"FRKey",
|
||||
"FRFail",
|
||||
"FImgFail",
|
||||
"FTTSFail",
|
||||
"FOcrFail",
|
||||
"FSttFail",
|
||||
"FUnknown",
|
||||
]
|
||||
],
|
||||
)
|
||||
|
||||
class _EmbeddingFail:
|
||||
def __init__(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
def encode(self, _texts):
|
||||
return [[]], 1
|
||||
|
||||
class _EmbeddingPass:
|
||||
def __init__(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
def encode(self, _texts):
|
||||
return [[0.5]], 1
|
||||
|
||||
class _ChatFail:
|
||||
def __init__(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
async def async_chat(self, *_args, **_kwargs):
|
||||
return "**ERROR**: chat failed", 0
|
||||
|
||||
class _ChatPass:
|
||||
def __init__(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
async def async_chat(self, *_args, **_kwargs):
|
||||
return "ok", 1
|
||||
|
||||
class _RerankFail:
|
||||
def __init__(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
def similarity(self, *_args, **_kwargs):
|
||||
return [], 1
|
||||
|
||||
class _CvFail:
|
||||
def __init__(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
def describe(self, _image_data):
|
||||
return "**ERROR**: image failed", 0
|
||||
|
||||
class _TTSFail:
|
||||
def __init__(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
def tts(self, _text):
|
||||
raise RuntimeError("tts fail")
|
||||
yield b"x"
|
||||
|
||||
class _OcrFail:
|
||||
def __init__(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
def check_available(self):
|
||||
return False, "ocr unavailable"
|
||||
|
||||
class _SttFail:
|
||||
def __init__(self, *_args, **_kwargs):
|
||||
raise RuntimeError("stt fail")
|
||||
|
||||
class _RerankKeyMap(dict):
|
||||
def __contains__(self, key):
|
||||
if key == "FRKey":
|
||||
return True
|
||||
return super().__contains__(key)
|
||||
|
||||
def __getitem__(self, key):
|
||||
if key == "FRKey":
|
||||
raise KeyError("rerank key fail")
|
||||
return super().__getitem__(key)
|
||||
|
||||
monkeypatch.setattr(module, "EmbeddingModel", {"FEmbFail": _EmbeddingFail, "FEmbPass": _EmbeddingPass})
|
||||
monkeypatch.setattr(module, "ChatModel", {"FChatFail": _ChatFail, "FChatPass": _ChatPass})
|
||||
monkeypatch.setattr(module, "RerankModel", _RerankKeyMap({"FRFail": _RerankFail}))
|
||||
monkeypatch.setattr(module, "CvModel", {"FImgFail": _CvFail})
|
||||
monkeypatch.setattr(module, "TTSModel", {"FTTSFail": _TTSFail})
|
||||
monkeypatch.setattr(module, "OcrModel", {"FOcrFail": _OcrFail})
|
||||
monkeypatch.setattr(module, "Seq2txtModel", {"FSttFail": _SttFail})
|
||||
|
||||
def _call(req):
|
||||
_set_request_json(monkeypatch, module, req)
|
||||
return _run(module.add_llm())
|
||||
|
||||
res = _call({"llm_factory": "FEmbFail", "llm_name": "m", "model_type": module.LLMType.EMBEDDING.value, "verify": True})
|
||||
assert res["code"] == 0
|
||||
assert res["data"]["success"] is False
|
||||
assert "Fail to access embedding model(m)." in res["data"]["message"]
|
||||
|
||||
res = _call({"llm_factory": "FEmbFail", "llm_name": "m", "model_type": module.LLMType.EMBEDDING.value})
|
||||
assert res["code"] == 400
|
||||
assert "Fail to access embedding model(m)." in res["message"]
|
||||
|
||||
res = _call({"llm_factory": "FChatFail", "llm_name": "m", "model_type": module.LLMType.CHAT.value, "verify": True})
|
||||
assert res["code"] == 0
|
||||
assert "Fail to access model(FChatFail/m)." in res["data"]["message"]
|
||||
|
||||
res = _call({"llm_factory": "FRKey", "llm_name": "m", "model_type": module.LLMType.RERANK.value, "verify": True})
|
||||
assert res["code"] == 0
|
||||
assert "dose not support this model(FRKey/m)" in res["data"]["message"]
|
||||
|
||||
res = _call({"llm_factory": "FRFail", "llm_name": "m", "model_type": module.LLMType.RERANK.value, "verify": True})
|
||||
assert res["code"] == 0
|
||||
assert "Fail to access model(FRFail/m)." in res["data"]["message"]
|
||||
|
||||
res = _call({"llm_factory": "FImgFail", "llm_name": "m", "model_type": module.LLMType.IMAGE2TEXT.value, "verify": True})
|
||||
assert res["code"] == 0
|
||||
assert "Fail to access model(FImgFail/m)." in res["data"]["message"]
|
||||
|
||||
res = _call({"llm_factory": "FTTSFail", "llm_name": "m", "model_type": module.LLMType.TTS.value, "verify": True})
|
||||
assert res["code"] == 0
|
||||
assert "Fail to access model(FTTSFail/m)." in res["data"]["message"]
|
||||
|
||||
res = _call({"llm_factory": "FOcrFail", "llm_name": "m", "model_type": module.LLMType.OCR.value, "verify": True})
|
||||
assert res["code"] == 0
|
||||
assert "Fail to access model(FOcrFail/m)." in res["data"]["message"]
|
||||
|
||||
res = _call({"llm_factory": "FSttFail", "llm_name": "m", "model_type": module.LLMType.SPEECH2TEXT.value, "verify": True})
|
||||
assert res["code"] == 0
|
||||
assert "Fail to access model(FSttFail/m)." in res["data"]["message"]
|
||||
|
||||
_set_request_json(monkeypatch, module, {"llm_factory": "FUnknown", "llm_name": "m", "model_type": "unknown"})
|
||||
with pytest.raises(RuntimeError, match="Unknown model type: unknown"):
|
||||
_run(module.add_llm())
|
||||
|
||||
saved = []
|
||||
monkeypatch.setattr(module.TenantLLMService, "filter_update", lambda _filters, _payload: False)
|
||||
monkeypatch.setattr(module.TenantLLMService, "save", lambda **kwargs: saved.append(kwargs) or True)
|
||||
res = _call({"llm_factory": "FChatPass", "llm_name": "m", "model_type": module.LLMType.CHAT.value, "api_key": "k"})
|
||||
assert res["code"] == 0
|
||||
assert res["data"] is True
|
||||
assert saved
|
||||
assert saved[0]["llm_factory"] == "FChatPass"
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_llm_mutation_routes_unit(monkeypatch):
|
||||
module = _load_llm_app(monkeypatch)
|
||||
calls = {"delete": [], "update": []}
|
||||
monkeypatch.setattr(module.TenantLLMService, "filter_delete", lambda filters: calls["delete"].append(filters) or True)
|
||||
monkeypatch.setattr(module.TenantLLMService, "filter_update", lambda filters, payload: calls["update"].append((filters, payload)) or True)
|
||||
|
||||
_set_request_json(monkeypatch, module, {"llm_factory": "OpenAI", "llm_name": "gpt"})
|
||||
res = _run(module.delete_llm())
|
||||
assert res["code"] == 0
|
||||
assert res["data"] is True
|
||||
|
||||
_set_request_json(monkeypatch, module, {"llm_factory": "OpenAI", "llm_name": "gpt", "status": 0})
|
||||
res = _run(module.enable_llm())
|
||||
assert res["code"] == 0
|
||||
assert res["data"] is True
|
||||
assert calls["update"][0][1]["status"] == "0"
|
||||
|
||||
_set_request_json(monkeypatch, module, {"llm_factory": "OpenAI"})
|
||||
res = _run(module.delete_factory())
|
||||
assert res["code"] == 0
|
||||
assert res["data"] is True
|
||||
assert len(calls["delete"]) == 2
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_my_llms_include_details_and_exception_unit(monkeypatch):
|
||||
module = _load_llm_app(monkeypatch)
|
||||
monkeypatch.setattr(module, "request", SimpleNamespace(args={"include_details": "true"}))
|
||||
ensure_calls = []
|
||||
monkeypatch.setattr(module.TenantLLMService, "ensure_mineru_from_env", lambda tenant_id: ensure_calls.append(tenant_id))
|
||||
monkeypatch.setattr(
|
||||
module.TenantLLMService,
|
||||
"query",
|
||||
lambda **_kwargs: [
|
||||
_TenantLLMRow(
|
||||
llm_name="chat-model",
|
||||
llm_factory="FactoryX",
|
||||
model_type="chat",
|
||||
used_tokens=42,
|
||||
api_base="",
|
||||
max_tokens=4096,
|
||||
status="1",
|
||||
)
|
||||
],
|
||||
)
|
||||
monkeypatch.setattr(module.LLMFactoriesService, "query", lambda **_kwargs: [SimpleNamespace(name="FactoryX", tags=["tag-a"])])
|
||||
res = module.my_llms()
|
||||
assert res["code"] == 0
|
||||
assert ensure_calls == ["tenant-1"]
|
||||
assert "FactoryX" in res["data"]
|
||||
assert res["data"]["FactoryX"]["tags"] == ["tag-a"]
|
||||
assert res["data"]["FactoryX"]["llm"][0]["used_token"] == 42
|
||||
assert res["data"]["FactoryX"]["llm"][0]["max_tokens"] == 4096
|
||||
|
||||
monkeypatch.setattr(module.TenantLLMService, "ensure_mineru_from_env", lambda _tenant_id: (_ for _ in ()).throw(RuntimeError("my llms boom")))
|
||||
res = module.my_llms()
|
||||
assert res["code"] == 500
|
||||
assert "my llms boom" in res["message"]
|
||||
|
||||
@ -16,6 +16,7 @@
|
||||
import asyncio
|
||||
import importlib.util
|
||||
import inspect
|
||||
import json
|
||||
import sys
|
||||
from functools import wraps
|
||||
from pathlib import Path
|
||||
@ -131,6 +132,16 @@ def _set_request_json(monkeypatch, module, payload):
|
||||
monkeypatch.setattr(module, "get_request_json", _request_json)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def auth():
|
||||
return "unit-auth"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def set_tenant_info():
|
||||
return None
|
||||
|
||||
|
||||
def _load_mcp_server_app(monkeypatch):
|
||||
repo_root = Path(__file__).resolve().parents[4]
|
||||
|
||||
@ -197,6 +208,28 @@ def _load_mcp_server_app(monkeypatch):
|
||||
api_utils_mod.get_mcp_tools = _get_mcp_tools
|
||||
monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod)
|
||||
|
||||
web_utils_mod = ModuleType("api.utils.web_utils")
|
||||
|
||||
def _get_float(data, key, default):
|
||||
try:
|
||||
return float(data.get(key, default))
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
def _safe_json_parse(value):
|
||||
if isinstance(value, (dict, list)):
|
||||
return value
|
||||
if value in (None, ""):
|
||||
return {}
|
||||
try:
|
||||
return json.loads(value)
|
||||
except (TypeError, ValueError):
|
||||
return {}
|
||||
|
||||
web_utils_mod.get_float = _get_float
|
||||
web_utils_mod.safe_json_parse = _safe_json_parse
|
||||
monkeypatch.setitem(sys.modules, "api.utils.web_utils", web_utils_mod)
|
||||
|
||||
module_name = "test_mcp_server_app_unit_module"
|
||||
module_path = repo_root / "api" / "apps" / "mcp_server_app.py"
|
||||
spec = importlib.util.spec_from_file_location(module_name, module_path)
|
||||
@ -706,3 +739,159 @@ def test_test_tool_missing_mcp_id(monkeypatch):
|
||||
_set_request_json(monkeypatch, module, {"mcp_id": "", "tool_name": "tool_a", "arguments": {"x": 1}})
|
||||
res = _run(module.test_tool.__wrapped__())
|
||||
assert "No MCP server ID provided" in res["message"]
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_test_tool_route_matrix_unit(monkeypatch):
|
||||
module = _load_mcp_server_app(monkeypatch)
|
||||
|
||||
_set_request_json(monkeypatch, module, {"mcp_id": "", "tool_name": "tool_a", "arguments": {"x": 1}})
|
||||
res = _run(module.test_tool.__wrapped__())
|
||||
assert "No MCP server ID provided" in res["message"]
|
||||
|
||||
_set_request_json(monkeypatch, module, {"mcp_id": "id1", "tool_name": "", "arguments": {"x": 1}})
|
||||
res = _run(module.test_tool.__wrapped__())
|
||||
assert "Require provide tool name and arguments" in res["message"]
|
||||
|
||||
_set_request_json(monkeypatch, module, {"mcp_id": "id1", "tool_name": "tool_a", "arguments": {}})
|
||||
res = _run(module.test_tool.__wrapped__())
|
||||
assert "Require provide tool name and arguments" in res["message"]
|
||||
|
||||
_set_request_json(monkeypatch, module, {"mcp_id": "id1", "tool_name": "tool_a", "arguments": {"x": 1}})
|
||||
monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (False, None))
|
||||
res = _run(module.test_tool.__wrapped__())
|
||||
assert "Cannot find MCP server id1 for user tenant_1" in res["message"]
|
||||
|
||||
server_other = _DummyMCPServer(id="id1", name="srv", url="http://a", server_type="sse", tenant_id="other", variables={})
|
||||
monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (True, server_other))
|
||||
res = _run(module.test_tool.__wrapped__())
|
||||
assert "Cannot find MCP server id1 for user tenant_1" in res["message"]
|
||||
|
||||
server_ok = _DummyMCPServer(id="id1", name="srv", url="http://a", server_type="sse", tenant_id="tenant_1", variables={})
|
||||
monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (True, server_ok))
|
||||
close_calls = []
|
||||
|
||||
async def _thread_pool_exec_success(func, *args):
|
||||
if func is module.close_multiple_mcp_toolcall_sessions:
|
||||
close_calls.append(args[0])
|
||||
return None
|
||||
return func(*args)
|
||||
|
||||
monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_exec_success)
|
||||
res = _run(module.test_tool.__wrapped__())
|
||||
assert res["code"] == 0
|
||||
assert res["data"] == "ok"
|
||||
assert close_calls and len(close_calls[-1]) == 1
|
||||
|
||||
async def _thread_pool_exec_raise(func, *args):
|
||||
if func is module.close_multiple_mcp_toolcall_sessions:
|
||||
return None
|
||||
raise RuntimeError("tool call explode")
|
||||
|
||||
monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_exec_raise)
|
||||
res = _run(module.test_tool.__wrapped__())
|
||||
assert res["code"] == 100
|
||||
assert "tool call explode" in res["message"]
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_cache_tool_route_matrix_unit(monkeypatch):
    """Exercise every branch of the ``cache_tool`` route in isolation.

    Scenarios, in order: missing MCP id, unknown server, server owned by a
    different tenant, failed persistence, and finally a successful update that
    rebuilds the cached tool map while skipping malformed entries.
    """
    module = _load_mcp_server_app(monkeypatch)

    # Missing MCP id -> validation error.
    _set_request_json(monkeypatch, module, {"mcp_id": "", "tools": [{"name": "tool_a"}]})
    res = _run(module.cache_tool.__wrapped__())
    assert "No MCP server ID provided" in res["message"]

    # Lookup miss -> "cannot find" error.
    _set_request_json(monkeypatch, module, {"mcp_id": "id1", "tools": [{"name": "tool_a"}]})
    monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (False, None))
    res = _run(module.cache_tool.__wrapped__())
    assert "Cannot find MCP server id1 for user tenant_1" in res["message"]

    # Server exists but belongs to another tenant -> same "cannot find" error.
    server_other = _DummyMCPServer(id="id1", name="srv", url="http://a", server_type="sse", tenant_id="other", variables={})
    monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (True, server_other))
    res = _run(module.cache_tool.__wrapped__())
    assert "Cannot find MCP server id1 for user tenant_1" in res["message"]

    # Persistence failure surfaces an update error.
    server_fail = _DummyMCPServer(id="id1", name="srv", url="http://a", server_type="sse", tenant_id="tenant_1", variables={})
    monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (True, server_fail))
    monkeypatch.setattr(module.MCPServerService, "filter_update", lambda *_args, **_kwargs: False)
    res = _run(module.cache_tool.__wrapped__())
    assert "Failed to updated MCP server" in res["message"]

    # Success path: previously cached tools are replaced, and entries that are
    # not dicts with a "name" key ({"bad": 1}, "x") are silently dropped.
    server_ok = _DummyMCPServer(
        id="id1",
        name="srv",
        url="http://a",
        server_type="sse",
        tenant_id="tenant_1",
        variables={"tools": {"old_tool": {"name": "old_tool"}}},
    )
    monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (True, server_ok))
    monkeypatch.setattr(module.MCPServerService, "filter_update", lambda *_args, **_kwargs: True)
    _set_request_json(
        monkeypatch,
        module,
        {
            "mcp_id": "id1",
            "tools": [{"name": "tool_a", "enabled": True}, {"bad": 1}, "x", {"name": "tool_b", "enabled": False}],
        },
    )
    res = _run(module.cache_tool.__wrapped__())
    assert res["code"] == 0
    assert sorted(res["data"].keys()) == ["tool_a", "tool_b"]
    # The per-tool "enabled" flag must survive the round trip into variables.
    assert server_ok.variables["tools"]["tool_b"]["enabled"] is False
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_test_mcp_route_matrix_unit(monkeypatch):
    """Exercise the ``test_mcp`` connectivity-check route branch by branch.

    Covers: empty URL, unsupported server type, a failure raised while listing
    tools (sessions must still be closed), the success path, and a failure
    constructing the tool-call session itself.
    """
    module = _load_mcp_server_app(monkeypatch)

    # Empty URL -> validation error.
    _set_request_json(monkeypatch, module, {"url": "", "server_type": "sse"})
    res = _run(module.test_mcp.__wrapped__())
    assert "Invalid MCP url" in res["message"]

    # Unknown transport type -> validation error.
    _set_request_json(monkeypatch, module, {"url": "http://a", "server_type": "invalid"})
    res = _run(module.test_mcp.__wrapped__())
    assert "Unsupported MCP server type" in res["message"]

    close_calls = []

    # Simulate get_tools blowing up while still recording session cleanup.
    async def _thread_pool_exec_inner_error(func, *args):
        if func is module.close_multiple_mcp_toolcall_sessions:
            close_calls.append(args[0])
            return None
        if getattr(func, "__name__", "") == "get_tools":
            raise RuntimeError("get tools explode")
        return func(*args)

    monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_exec_inner_error)
    _set_request_json(monkeypatch, module, {"url": "http://a", "server_type": "sse"})
    res = _run(module.test_mcp.__wrapped__())
    assert res["code"] == 102
    assert "Test MCP error: get tools explode" in res["message"]
    # Exactly one session must have been handed to the cleanup helper.
    assert close_calls and len(close_calls[-1]) == 1

    close_calls_success = []

    # Happy path: tools are listed and every tool defaults to enabled.
    async def _thread_pool_exec_success(func, *args):
        if func is module.close_multiple_mcp_toolcall_sessions:
            close_calls_success.append(args[0])
            return None
        return func(*args)

    monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_exec_success)
    _set_request_json(monkeypatch, module, {"url": "http://a", "server_type": "sse"})
    res = _run(module.test_mcp.__wrapped__())
    assert res["code"] == 0
    assert res["data"][0]["name"] == "tool_a"
    assert all(tool["enabled"] is True for tool in res["data"])
    assert close_calls_success and len(close_calls_success[-1]) == 1

    # Session construction failure is reported as a generic exception.
    def _raise_session(*_args, **_kwargs):
        raise RuntimeError("session explode")

    monkeypatch.setattr(module, "MCPToolCallSession", _raise_session)
    _set_request_json(monkeypatch, module, {"url": "http://a", "server_type": "sse"})
    res = _run(module.test_mcp.__wrapped__())
    assert res["code"] == 100
    assert "session explode" in res["message"]
|
||||
|
||||
@ -0,0 +1,509 @@
|
||||
#
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import asyncio
|
||||
from copy import deepcopy
|
||||
import importlib.util
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from types import ModuleType, SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class _DummyManager:
|
||||
def route(self, *_args, **_kwargs):
|
||||
def decorator(func):
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class _DummyAtomic:
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, _exc_type, _exc, _tb):
|
||||
return False
|
||||
|
||||
|
||||
class _Args(dict):
|
||||
def get(self, key, default=None):
|
||||
return super().get(key, default)
|
||||
|
||||
|
||||
class _EnumValue:
    """Bare stand-in for an enum member: exposes only the ``.value`` attribute."""

    def __init__(self, value):
        # Mirror the real enum API, where members carry their payload in .value.
        self.value = value
|
||||
|
||||
|
||||
class _DummyStatusEnum:
    # Minimal mirror of common.constants.StatusEnum: the routes under test only
    # ever read StatusEnum.VALID.value, so a single member suffices.
    VALID = _EnumValue("1")
|
||||
|
||||
|
||||
class _DummyRetCode:
    # Numeric return codes asserted on by the route tests below. The values
    # presumably mirror common.constants.RetCode in the real codebase --
    # verify against that module if the production codes ever change.
    SUCCESS = 0
    EXCEPTION_ERROR = 100
    ARGUMENT_ERROR = 101
    DATA_ERROR = 102
    OPERATING_ERROR = 103
    AUTHENTICATION_ERROR = 109
|
||||
|
||||
|
||||
class _SearchRecord:
|
||||
def __init__(self, search_id="search-1", name="search", search_config=None):
|
||||
self.id = search_id
|
||||
self.name = name
|
||||
self.search_config = {} if search_config is None else dict(search_config)
|
||||
|
||||
def to_dict(self):
|
||||
return {"id": self.id, "name": self.name, "search_config": dict(self.search_config)}
|
||||
|
||||
|
||||
def _run(coro):
|
||||
return asyncio.run(coro)
|
||||
|
||||
|
||||
def _set_request_json(monkeypatch, module, payload):
|
||||
async def _request_json():
|
||||
return deepcopy(payload)
|
||||
|
||||
monkeypatch.setattr(module, "get_request_json", _request_json)
|
||||
|
||||
|
||||
def _set_request_args(monkeypatch, module, args=None):
    """Install a stub ``request`` on *module* carrying *args* as query params."""
    # args=None (or any falsy value) yields an empty query-string mapping.
    monkeypatch.setattr(module, "request", SimpleNamespace(args=_Args(args or {})))
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def auth():
    """Constant auth-token placeholder for these in-process unit tests.

    Presumably shadows a suite-wide ``auth`` fixture from conftest -- these
    tests never issue HTTP requests, so a fixed string is enough.
    """
    return "unit-auth"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
def set_tenant_info():
    """No-op autouse fixture; presumably overrides a conftest fixture of the
    same name so no real tenant bootstrap runs for these unit tests."""
    return None
|
||||
|
||||
|
||||
def _load_search_app(monkeypatch):
    """Import ``api/apps/search_app.py`` with its whole dependency tree stubbed.

    Builds fake ``quart``, ``common.*``, ``api.*`` modules in ``sys.modules``
    (via monkeypatch, so the patches unwind after each test), then loads the
    route module from its file path and returns it. The ordering matters:
    every import that ``search_app`` performs at module top level must already
    be satisfied before ``exec_module`` runs.
    """
    # The test file lives four directories below the repository root.
    repo_root = Path(__file__).resolve().parents[4]

    # --- quart: only `request` is needed at import time. ---
    quart_mod = ModuleType("quart")
    quart_mod.request = SimpleNamespace(args=_Args())
    monkeypatch.setitem(sys.modules, "quart", quart_mod)

    # --- common package and its submodules. ---
    common_pkg = ModuleType("common")
    common_pkg.__path__ = [str(repo_root / "common")]
    monkeypatch.setitem(sys.modules, "common", common_pkg)

    misc_utils_mod = ModuleType("common.misc_utils")
    # Deterministic uuid so the create() success assertion can pin the id.
    misc_utils_mod.get_uuid = lambda: "search-uuid-1"
    monkeypatch.setitem(sys.modules, "common.misc_utils", misc_utils_mod)
    common_pkg.misc_utils = misc_utils_mod

    constants_mod = ModuleType("common.constants")
    constants_mod.RetCode = _DummyRetCode
    constants_mod.StatusEnum = _DummyStatusEnum
    monkeypatch.setitem(sys.modules, "common.constants", constants_mod)
    common_pkg.constants = constants_mod

    # --- api package skeleton. ---
    api_pkg = ModuleType("api")
    api_pkg.__path__ = [str(repo_root / "api")]
    monkeypatch.setitem(sys.modules, "api", api_pkg)

    apps_mod = ModuleType("api.apps")
    apps_mod.__path__ = [str(repo_root / "api" / "apps")]
    apps_mod.current_user = SimpleNamespace(id="tenant-1")
    # Disable auth decoration so handlers stay plain callables.
    apps_mod.login_required = lambda func: func
    monkeypatch.setitem(sys.modules, "api.apps", apps_mod)
    api_pkg.apps = apps_mod

    constants_api_mod = ModuleType("api.constants")
    constants_api_mod.DATASET_NAME_LIMIT = 255
    monkeypatch.setitem(sys.modules, "api.constants", constants_api_mod)

    db_pkg = ModuleType("api.db")
    db_pkg.__path__ = []
    monkeypatch.setitem(sys.modules, "api.db", db_pkg)
    api_pkg.db = db_pkg

    db_models_mod = ModuleType("api.db.db_models")

    class _DummyDB:
        @staticmethod
        def atomic():
            # Transactions become no-ops via the dummy context manager.
            return _DummyAtomic()

    db_models_mod.DB = _DummyDB
    monkeypatch.setitem(sys.modules, "api.db.db_models", db_models_mod)

    services_pkg = ModuleType("api.db.services")
    services_pkg.__path__ = []
    # Default duplicate_name passes the requested name straight through.
    services_pkg.duplicate_name = lambda _checker, **kwargs: kwargs.get("name", "")
    monkeypatch.setitem(sys.modules, "api.db.services", services_pkg)

    # --- SearchService stub: benign defaults, overridden per test scenario. ---
    search_service_mod = ModuleType("api.db.services.search_service")

    class _SearchService:
        @staticmethod
        def query(**_kwargs):
            return []

        @staticmethod
        def save(**_kwargs):
            return True

        @staticmethod
        def accessible4deletion(_search_id, _user_id):
            return True

        @staticmethod
        def update_by_id(_search_id, _req):
            return True

        @staticmethod
        def get_by_id(_search_id):
            return True, _SearchRecord(search_id=_search_id, name="updated")

        @staticmethod
        def get_detail(_search_id):
            return {"id": _search_id}

        @staticmethod
        def get_by_tenant_ids(_tenants, _user_id, _page_number, _items_per_page, _orderby, _desc, _keywords):
            return [], 0

        @staticmethod
        def delete_by_id(_search_id):
            return True

    search_service_mod.SearchService = _SearchService
    monkeypatch.setitem(sys.modules, "api.db.services.search_service", search_service_mod)

    # --- Tenant services: a single tenant "tenant-1" owned by the user. ---
    user_service_mod = ModuleType("api.db.services.user_service")

    class _TenantService:
        @staticmethod
        def get_by_id(_tenant_id):
            return True, SimpleNamespace(id=_tenant_id)

    class _UserTenantService:
        @staticmethod
        def query(**_kwargs):
            return [SimpleNamespace(tenant_id="tenant-1")]

    user_service_mod.TenantService = _TenantService
    user_service_mod.UserTenantService = _UserTenantService
    monkeypatch.setitem(sys.modules, "api.db.services.user_service", user_service_mod)

    # --- api.utils.api_utils: result helpers that return plain dicts so the
    # tests can assert on code/message/data without a Quart response object. ---
    utils_pkg = ModuleType("api.utils")
    utils_pkg.__path__ = []
    monkeypatch.setitem(sys.modules, "api.utils", utils_pkg)

    api_utils_mod = ModuleType("api.utils.api_utils")

    async def _default_request_json():
        return {}

    def _get_data_error_result(code=_DummyRetCode.DATA_ERROR, message="Sorry! Data missing!"):
        return {"code": code, "message": message}

    def _get_json_result(code=_DummyRetCode.SUCCESS, message="success", data=None):
        return {"code": code, "message": message, "data": data}

    def _server_error_response(error):
        return {"code": _DummyRetCode.EXCEPTION_ERROR, "message": repr(error)}

    def _validate_request(*_args, **_kwargs):
        def _decorator(func):
            return func

        return _decorator

    def _not_allowed_parameters(*_params):
        def _decorator(func):
            return func

        return _decorator

    api_utils_mod.get_request_json = _default_request_json
    api_utils_mod.get_data_error_result = _get_data_error_result
    api_utils_mod.get_json_result = _get_json_result
    api_utils_mod.server_error_response = _server_error_response
    api_utils_mod.validate_request = _validate_request
    api_utils_mod.not_allowed_parameters = _not_allowed_parameters
    monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod)
    utils_pkg.api_utils = api_utils_mod

    # --- Load the real route module from disk against the stubs above. ---
    module_name = "test_search_routes_unit_module"
    module_path = repo_root / "api" / "apps" / "search_app.py"
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    module = importlib.util.module_from_spec(spec)
    # `manager` must exist before exec: the module decorates routes with it.
    module.manager = _DummyManager()
    monkeypatch.setitem(sys.modules, module_name, module)
    spec.loader.exec_module(module)
    return module
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_create_route_matrix_unit(monkeypatch):
    """Walk the ``create`` route through its validation and persistence branches.

    Order: non-string name, blank name, over-length name, tenant lookup
    failure, save() returning False, the success path (pinning the stubbed
    uuid), and an exception raised during save().
    """
    module = _load_search_app(monkeypatch)

    # Name must be a string.
    _set_request_json(monkeypatch, module, {"name": 1})
    res = _run(module.create())
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "must be string" in res["message"]

    # Whitespace-only name is rejected as empty.
    _set_request_json(monkeypatch, module, {"name": " "})
    res = _run(module.create())
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "empty" in res["message"].lower()

    # 256 chars exceeds the 255-char limit stubbed in api.constants.
    _set_request_json(monkeypatch, module, {"name": "a" * 256})
    res = _run(module.create())
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "255" in res["message"]

    # Tenant lookup failure -> "authorized identity" error.
    _set_request_json(monkeypatch, module, {"name": "create-auth-fail"})
    monkeypatch.setattr(module.TenantService, "get_by_id", lambda _tenant_id: (False, None))
    res = _run(module.create())
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "authorized identity" in res["message"].lower()

    # Restore tenant lookup; make duplicate_name observable via a suffix.
    monkeypatch.setattr(module.TenantService, "get_by_id", lambda _tenant_id: (True, SimpleNamespace(id=_tenant_id)))
    monkeypatch.setattr(module, "duplicate_name", lambda _checker, **kwargs: kwargs["name"] + "_dedup")
    _set_request_json(monkeypatch, module, {"name": "create-fail", "description": "d"})
    monkeypatch.setattr(module.SearchService, "save", lambda **_kwargs: False)
    res = _run(module.create())
    assert res["code"] == module.RetCode.DATA_ERROR

    # Success path: id comes from the stubbed get_uuid.
    _set_request_json(monkeypatch, module, {"name": "create-ok", "description": "d"})
    monkeypatch.setattr(module.SearchService, "save", lambda **_kwargs: True)
    res = _run(module.create())
    assert res["code"] == 0
    assert res["data"]["search_id"] == "search-uuid-1"

    # Exception inside save() is wrapped by server_error_response.
    def _raise_save(**_kwargs):
        raise RuntimeError("save boom")

    monkeypatch.setattr(module.SearchService, "save", _raise_save)
    _set_request_json(monkeypatch, module, {"name": "create-exception", "description": "d"})
    res = _run(module.create())
    assert res["code"] == module.RetCode.EXCEPTION_ERROR
    assert "save boom" in res["message"]
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_update_and_detail_route_matrix_unit(monkeypatch):
    """Branch matrix for the ``update`` route, then the ``detail`` route.

    ``update``: name validation, tenant check, deletion-permission check,
    missing record, duplicate name, non-dict search_config, update failure
    (asserting the sanitized payload), post-update fetch failure, success,
    and an exception path. ``detail``: permission failure, missing detail,
    success, and an exception path.
    """
    module = _load_search_app(monkeypatch)

    # Name must be a string.
    _set_request_json(monkeypatch, module, {"search_id": "s1", "name": 1, "search_config": {}, "tenant_id": "tenant-1"})
    res = _run(module.update())
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "must be string" in res["message"]

    # Whitespace-only name is rejected.
    _set_request_json(monkeypatch, module, {"search_id": "s1", "name": " ", "search_config": {}, "tenant_id": "tenant-1"})
    res = _run(module.update())
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "empty" in res["message"].lower()

    # Over-length name (the route's message reads "large than").
    _set_request_json(monkeypatch, module, {"search_id": "s1", "name": "a" * 256, "search_config": {}, "tenant_id": "tenant-1"})
    res = _run(module.update())
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "large than" in res["message"]

    # Tenant lookup failure.
    _set_request_json(monkeypatch, module, {"search_id": "s1", "name": "ok", "search_config": {}, "tenant_id": "tenant-1"})
    monkeypatch.setattr(module.TenantService, "get_by_id", lambda _tenant_id: (False, None))
    res = _run(module.update())
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "authorized identity" in res["message"].lower()

    # Permission check failure -> authentication error.
    monkeypatch.setattr(module.TenantService, "get_by_id", lambda _tenant_id: (True, SimpleNamespace(id=_tenant_id)))
    monkeypatch.setattr(module.SearchService, "accessible4deletion", lambda _search_id, _user_id: False)
    _set_request_json(monkeypatch, module, {"search_id": "s1", "name": "ok", "search_config": {}, "tenant_id": "tenant-1"})
    res = _run(module.update())
    assert res["code"] == module.RetCode.AUTHENTICATION_ERROR
    assert "authorization" in res["message"].lower()

    # Record lookup misses -> "cannot find search".
    monkeypatch.setattr(module.SearchService, "accessible4deletion", lambda _search_id, _user_id: True)
    monkeypatch.setattr(module.SearchService, "query", lambda **_kwargs: [None])
    _set_request_json(monkeypatch, module, {"search_id": "s1", "name": "ok", "search_config": {}, "tenant_id": "tenant-1"})
    res = _run(module.update())
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "cannot find search" in res["message"].lower()

    existing = _SearchRecord(search_id="s1", name="old-name", search_config={"existing": 1})

    # Lookup by id finds the record; lookup by name reports a different owner
    # of the requested new name -> duplicate-name error.
    def _query_duplicate(**kwargs):
        if "id" in kwargs:
            return [existing]
        if "name" in kwargs:
            return [SimpleNamespace(id="dup")]
        return []

    monkeypatch.setattr(module.SearchService, "query", _query_duplicate)
    _set_request_json(monkeypatch, module, {"search_id": "s1", "name": "new-name", "search_config": {}, "tenant_id": "tenant-1"})
    res = _run(module.update())
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "duplicated" in res["message"].lower()

    # search_config must be a JSON object, not a list.
    monkeypatch.setattr(module.SearchService, "query", lambda **_kwargs: [existing])
    _set_request_json(monkeypatch, module, {"search_id": "s1", "name": "old-name", "search_config": [], "tenant_id": "tenant-1"})
    res = _run(module.update())
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "json object" in res["message"].lower()

    captured = {}

    # update_by_id failure; also capture what the route actually persists.
    def _update_fail(search_id, req):
        captured["search_id"] = search_id
        captured["req"] = dict(req)
        return False

    monkeypatch.setattr(module.SearchService, "update_by_id", _update_fail)
    _set_request_json(monkeypatch, module, {"search_id": "s1", "name": "old-name", "search_config": {"top_k": 3}, "tenant_id": "tenant-1"})
    res = _run(module.update())
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "failed to update" in res["message"].lower()
    assert captured["search_id"] == "s1"
    # Identifier fields are stripped, and the new config is merged on top of
    # the existing one rather than replacing it.
    assert "search_id" not in captured["req"]
    assert "tenant_id" not in captured["req"]
    assert captured["req"]["search_config"] == {"existing": 1, "top_k": 3}

    # Update succeeds but re-fetching the record fails.
    monkeypatch.setattr(module.SearchService, "update_by_id", lambda _search_id, _req: True)
    monkeypatch.setattr(module.SearchService, "get_by_id", lambda _search_id: (False, None))
    res = _run(module.update())
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "failed to fetch" in res["message"].lower()

    # Full success path.
    monkeypatch.setattr(
        module.SearchService,
        "get_by_id",
        lambda _search_id: (True, _SearchRecord(search_id=_search_id, name="old-name", search_config={"existing": 1, "top_k": 3})),
    )
    res = _run(module.update())
    assert res["code"] == 0
    assert res["data"]["id"] == "s1"

    # Exception anywhere in the flow is wrapped by server_error_response.
    def _raise_query(**_kwargs):
        raise RuntimeError("update boom")

    monkeypatch.setattr(module.SearchService, "query", _raise_query)
    _set_request_json(monkeypatch, module, {"search_id": "s1", "name": "old-name", "search_config": {"top_k": 3}, "tenant_id": "tenant-1"})
    res = _run(module.update())
    assert res["code"] == module.RetCode.EXCEPTION_ERROR
    assert "update boom" in res["message"]

    # --- detail route (synchronous handler, reads request.args). ---
    # No matching search for the user's tenants -> permission error.
    _set_request_args(monkeypatch, module, {"search_id": "s1"})
    monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [SimpleNamespace(tenant_id="tenant-a")])
    monkeypatch.setattr(module.SearchService, "query", lambda **_kwargs: [])
    res = module.detail()
    assert res["code"] == module.RetCode.OPERATING_ERROR
    assert "permission" in res["message"].lower()

    # Record visible but get_detail returns nothing.
    monkeypatch.setattr(module.SearchService, "query", lambda **_kwargs: [SimpleNamespace(id="s1")])
    monkeypatch.setattr(module.SearchService, "get_detail", lambda _search_id: None)
    res = module.detail()
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "can't find" in res["message"].lower()

    # Success path.
    monkeypatch.setattr(module.SearchService, "get_detail", lambda _search_id: {"id": _search_id, "name": "detail-name"})
    res = module.detail()
    assert res["code"] == 0
    assert res["data"]["id"] == "s1"

    # Exception path.
    def _raise_detail(_search_id):
        raise RuntimeError("detail boom")

    monkeypatch.setattr(module.SearchService, "get_detail", _raise_detail)
    res = module.detail()
    assert res["code"] == module.RetCode.EXCEPTION_ERROR
    assert "detail boom" in res["message"]
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_list_and_rm_route_matrix_unit(monkeypatch):
    """Branch matrix for the ``list_search_app`` and ``rm`` routes.

    Listing: default success, owner-id filtering (results restricted to the
    requested tenant), and an exception path. Removal: permission failure,
    delete failure, success, and an exception path.
    """
    module = _load_search_app(monkeypatch)

    # Plain listing: one record, total of 1.
    _set_request_args(
        monkeypatch,
        module,
        {"keywords": "k", "page": "1", "page_size": "2", "orderby": "create_time", "desc": "false"},
    )
    _set_request_json(monkeypatch, module, {"owner_ids": []})
    monkeypatch.setattr(
        module.SearchService,
        "get_by_tenant_ids",
        lambda _tenants, _uid, _page, _size, _orderby, _desc, _keywords: ([{"id": "a", "tenant_id": "tenant-1"}], 1),
    )
    res = _run(module.list_search_app())
    assert res["code"] == 0
    assert res["data"]["total"] == 1
    assert res["data"]["search_apps"][0]["id"] == "a"

    # owner_ids filter: only rows belonging to the requested tenant survive,
    # and total reflects the filtered count.
    _set_request_args(
        monkeypatch,
        module,
        {"keywords": "k", "page": "1", "page_size": "1", "orderby": "create_time", "desc": "true"},
    )
    _set_request_json(monkeypatch, module, {"owner_ids": ["tenant-1"]})
    monkeypatch.setattr(
        module.SearchService,
        "get_by_tenant_ids",
        lambda _tenants, _uid, _page, _size, _orderby, _desc, _keywords: (
            [{"id": "x", "tenant_id": "tenant-1"}, {"id": "y", "tenant_id": "tenant-2"}],
            2,
        ),
    )
    res = _run(module.list_search_app())
    assert res["code"] == 0
    assert res["data"]["total"] == 1
    assert len(res["data"]["search_apps"]) == 1
    assert res["data"]["search_apps"][0]["tenant_id"] == "tenant-1"

    # Listing exception path.
    def _raise_list(*_args, **_kwargs):
        raise RuntimeError("list boom")

    monkeypatch.setattr(module.SearchService, "get_by_tenant_ids", _raise_list)
    _set_request_json(monkeypatch, module, {"owner_ids": []})
    res = _run(module.list_search_app())
    assert res["code"] == module.RetCode.EXCEPTION_ERROR
    assert "list boom" in res["message"]

    # rm: permission denied.
    _set_request_json(monkeypatch, module, {"search_id": "search-1"})
    monkeypatch.setattr(module.SearchService, "accessible4deletion", lambda _search_id, _user_id: False)
    res = _run(module.rm())
    assert res["code"] == module.RetCode.AUTHENTICATION_ERROR
    assert "authorization" in res["message"].lower()

    # rm: delete reports failure.
    monkeypatch.setattr(module.SearchService, "accessible4deletion", lambda _search_id, _user_id: True)
    monkeypatch.setattr(module.SearchService, "delete_by_id", lambda _search_id: False)
    res = _run(module.rm())
    assert res["code"] == module.RetCode.DATA_ERROR
    assert "failed to delete" in res["message"].lower()

    # rm: success.
    monkeypatch.setattr(module.SearchService, "delete_by_id", lambda _search_id: True)
    res = _run(module.rm())
    assert res["code"] == 0
    assert res["data"] is True

    # rm: exception path.
    def _raise_delete(_search_id):
        raise RuntimeError("rm boom")

    monkeypatch.setattr(module.SearchService, "delete_by_id", _raise_delete)
    res = _run(module.rm())
    assert res["code"] == module.RetCode.EXCEPTION_ERROR
    assert "rm boom" in res["message"]
|
||||
@ -0,0 +1,322 @@
|
||||
#
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import importlib.util
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from types import ModuleType, SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class _DummyManager:
|
||||
def route(self, *_args, **_kwargs):
|
||||
def decorator(func):
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class _ExprField:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
def __eq__(self, other):
|
||||
return (self.name, other)
|
||||
|
||||
|
||||
class _DummyAPITokenModel:
    # Stand-in for the APIToken DB model: class-level fields whose ``==``
    # yields (name, value) expression tuples via _ExprField, so route code
    # that builds peewee-style filters keeps working.
    tenant_id = _ExprField("tenant_id")
    token = _ExprField("token")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def auth():
    """Constant auth-token placeholder; these unit tests never call the API.

    Presumably shadows a suite-wide ``auth`` fixture from conftest.
    """
    return "unit-auth"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
def set_tenant_info():
    """No-op autouse fixture; presumably overrides a conftest fixture of the
    same name so no real tenant bootstrap runs for these unit tests."""
    return None
|
||||
|
||||
|
||||
def _load_system_module(monkeypatch):
    """Import ``api/apps/system_app.py`` with its dependency tree stubbed.

    Installs fake ``api.*``, ``common.*``, ``rag.*`` and ``quart`` modules in
    ``sys.modules`` (via monkeypatch so they unwind automatically), then loads
    the route module from its file path and returns it. Every name the module
    imports at top level must exist before ``exec_module`` runs.
    """
    # The test file lives four directories below the repository root.
    repo_root = Path(__file__).resolve().parents[4]

    # --- api package skeleton with a disabled login decorator. ---
    api_pkg = ModuleType("api")
    api_pkg.__path__ = [str(repo_root / "api")]
    monkeypatch.setitem(sys.modules, "api", api_pkg)

    apps_mod = ModuleType("api.apps")
    apps_mod.__path__ = [str(repo_root / "api" / "apps")]
    apps_mod.login_required = lambda fn: fn
    apps_mod.current_user = SimpleNamespace(id="user-1")
    monkeypatch.setitem(sys.modules, "api.apps", apps_mod)

    # --- common.settings: healthy doc store, storage, DB flags. ---
    common_pkg = ModuleType("common")
    common_pkg.__path__ = [str(repo_root / "common")]
    monkeypatch.setitem(sys.modules, "common", common_pkg)

    settings_mod = ModuleType("common.settings")
    settings_mod.docStoreConn = SimpleNamespace(health=lambda: {"type": "doc", "status": "green"})
    settings_mod.STORAGE_IMPL = SimpleNamespace(health=lambda: True)
    settings_mod.STORAGE_IMPL_TYPE = "MINIO"
    settings_mod.DATABASE_TYPE = "MYSQL"
    settings_mod.REGISTER_ENABLED = True
    common_pkg.settings = settings_mod
    monkeypatch.setitem(sys.modules, "common.settings", settings_mod)

    versions_mod = ModuleType("common.versions")
    versions_mod.get_ragflow_version = lambda: "0.0.0-unit"
    monkeypatch.setitem(sys.modules, "common.versions", versions_mod)

    # Frozen clock helpers keep responses deterministic.
    time_utils_mod = ModuleType("common.time_utils")
    time_utils_mod.current_timestamp = lambda: 111
    time_utils_mod.datetime_format = lambda _dt: "2026-01-01 00:00:00"
    monkeypatch.setitem(sys.modules, "common.time_utils", time_utils_mod)

    # --- api.utils.api_utils: result helpers returning plain dicts. ---
    api_utils_mod = ModuleType("api.utils.api_utils")
    api_utils_mod.get_json_result = lambda data=None, message="success", code=0: {
        "code": code,
        "message": message,
        "data": data,
    }
    api_utils_mod.get_data_error_result = lambda message="", code=102, data=None: {
        "code": code,
        "message": message,
        "data": data,
    }
    api_utils_mod.server_error_response = lambda exc: {
        "code": 100,
        "message": repr(exc),
        "data": None,
    }
    # Fixed token so token-creation tests can assert an exact value.
    api_utils_mod.generate_confirmation_token = lambda: "ragflow-abcdefghijklmnopqrstuvwxyz0123456789"
    monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod)

    # --- DB service stubs with benign defaults. ---
    api_service_mod = ModuleType("api.db.services.api_service")
    api_service_mod.APITokenService = SimpleNamespace(
        save=lambda **_kwargs: True,
        query=lambda **_kwargs: [],
        filter_update=lambda *_args, **_kwargs: True,
        filter_delete=lambda *_args, **_kwargs: True,
    )
    monkeypatch.setitem(sys.modules, "api.db.services.api_service", api_service_mod)

    kb_service_mod = ModuleType("api.db.services.knowledgebase_service")
    kb_service_mod.KnowledgebaseService = SimpleNamespace(get_by_id=lambda _kb_id: True)
    monkeypatch.setitem(sys.modules, "api.db.services.knowledgebase_service", kb_service_mod)

    user_service_mod = ModuleType("api.db.services.user_service")
    user_service_mod.UserTenantService = SimpleNamespace(
        query=lambda **_kwargs: [SimpleNamespace(role="owner", tenant_id="tenant-1")]
    )
    monkeypatch.setitem(sys.modules, "api.db.services.user_service", user_service_mod)

    db_models_mod = ModuleType("api.db.db_models")
    db_models_mod.APIToken = _DummyAPITokenModel
    monkeypatch.setitem(sys.modules, "api.db.db_models", db_models_mod)

    # --- rag.utils.redis_conn: healthy Redis with no task-executor beats. ---
    rag_pkg = ModuleType("rag")
    rag_pkg.__path__ = []
    monkeypatch.setitem(sys.modules, "rag", rag_pkg)

    rag_utils_pkg = ModuleType("rag.utils")
    rag_utils_pkg.__path__ = []
    monkeypatch.setitem(sys.modules, "rag.utils", rag_utils_pkg)

    redis_mod = ModuleType("rag.utils.redis_conn")
    redis_mod.REDIS_CONN = SimpleNamespace(
        health=lambda: True,
        smembers=lambda *_args, **_kwargs: set(),
        zrangebyscore=lambda *_args, **_kwargs: [],
    )
    monkeypatch.setitem(sys.modules, "rag.utils.redis_conn", redis_mod)

    health_utils_mod = ModuleType("api.utils.health_utils")
    health_utils_mod.run_health_checks = lambda: ({"status": "ok"}, True)
    health_utils_mod.get_oceanbase_status = lambda: {"status": "alive"}
    monkeypatch.setitem(sys.modules, "api.utils.health_utils", health_utils_mod)

    # --- quart: only jsonify is needed; pass payloads through unchanged. ---
    quart_mod = ModuleType("quart")
    quart_mod.jsonify = lambda payload: payload
    monkeypatch.setitem(sys.modules, "quart", quart_mod)

    # --- Load the real route module from disk against the stubs above. ---
    module_path = repo_root / "api" / "apps" / "system_app.py"
    spec = importlib.util.spec_from_file_location("test_system_routes_unit_module", module_path)
    module = importlib.util.module_from_spec(spec)
    # `manager` must exist before exec: the module decorates routes with it.
    module.manager = _DummyManager()
    monkeypatch.setitem(sys.modules, "test_system_routes_unit_module", module)
    spec.loader.exec_module(module)
    return module
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_status_branch_matrix_unit(monkeypatch):
    """Drive the system status route through an all-green and an all-red pass.

    The module under test is loaded in isolation by _load_system_module, so
    every external dependency (doc store, storage, DB, Redis) is a stub that
    this test swaps per branch.
    """
    module = _load_system_module(monkeypatch)

    # All dependencies report healthy.
    monkeypatch.setattr(module.settings, "docStoreConn", SimpleNamespace(health=lambda: {"type": "es", "status": "green"}))
    monkeypatch.setattr(module.settings, "STORAGE_IMPL", SimpleNamespace(health=lambda: True))
    monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: True)
    monkeypatch.setattr(module.REDIS_CONN, "health", lambda: True)
    monkeypatch.setattr(module.REDIS_CONN, "smembers", lambda _key: {"executor-1"})
    monkeypatch.setattr(module.REDIS_CONN, "zrangebyscore", lambda *_args, **_kwargs: ['{"beat": 1}'])

    res = module.status()
    assert res["code"] == 0
    assert res["data"]["doc_engine"]["status"] == "green"
    assert res["data"]["storage"]["status"] == "green"
    assert res["data"]["database"]["status"] == "green"
    assert res["data"]["redis"]["status"] == "green"
    # Heartbeat JSON strings from Redis are decoded per executor id.
    assert res["data"]["task_executor_heartbeats"]["executor-1"][0]["beat"] == 1

    # Every dependency failing: each component should report "red" and carry
    # its own error text; the route itself still returns code 0.
    monkeypatch.setattr(
        module.settings,
        "docStoreConn",
        SimpleNamespace(health=lambda: (_ for _ in ()).throw(RuntimeError("doc down"))),
    )
    monkeypatch.setattr(
        module.settings,
        "STORAGE_IMPL",
        SimpleNamespace(health=lambda: (_ for _ in ()).throw(RuntimeError("storage down"))),
    )
    monkeypatch.setattr(
        module.KnowledgebaseService,
        "get_by_id",
        lambda _kb_id: (_ for _ in ()).throw(RuntimeError("db down")),
    )
    monkeypatch.setattr(module.REDIS_CONN, "health", lambda: False)
    monkeypatch.setattr(module.REDIS_CONN, "smembers", lambda _key: (_ for _ in ()).throw(RuntimeError("hb down")))

    res = module.status()
    assert res["code"] == 0
    assert res["data"]["doc_engine"]["status"] == "red"
    assert "doc down" in res["data"]["doc_engine"]["error"]
    assert res["data"]["storage"]["status"] == "red"
    assert "storage down" in res["data"]["storage"]["error"]
    assert res["data"]["database"]["status"] == "red"
    assert "db down" in res["data"]["database"]["error"]
    assert res["data"]["redis"]["status"] == "red"
    assert "Lost connection!" in res["data"]["redis"]["error"]
    # A failing heartbeat scan degrades to an empty mapping, not an error.
    assert res["data"]["task_executor_heartbeats"] == {}
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_healthz_and_oceanbase_status_matrix_unit(monkeypatch):
    """Cover healthz 200/500 branches and oceanbase_status success/error branches."""
    module = _load_system_module(monkeypatch)

    # Healthy: checks pass -> HTTP 200 with the check payload.
    monkeypatch.setattr(module, "run_health_checks", lambda: ({"status": "ok"}, True))
    payload, status = module.healthz()
    assert status == 200
    assert payload["status"] == "ok"

    # Degraded: checker returns False -> HTTP 500, payload passed through.
    monkeypatch.setattr(module, "run_health_checks", lambda: ({"status": "degraded"}, False))
    payload, status = module.healthz()
    assert status == 500
    assert payload["status"] == "degraded"

    # OceanBase probe succeeds -> code 0 with the probe's data.
    monkeypatch.setattr(module, "get_oceanbase_status", lambda: {"status": "alive", "latency_ms": 8})
    res = module.oceanbase_status()
    assert res["code"] == 0
    assert res["data"]["status"] == "alive"

    # OceanBase probe raises -> code 500 and the exception text is surfaced.
    monkeypatch.setattr(module, "get_oceanbase_status", lambda: (_ for _ in ()).throw(RuntimeError("ocean boom")))
    res = module.oceanbase_status()
    assert res["code"] == 500
    assert res["data"]["status"] == "error"
    assert "ocean boom" in res["data"]["message"]
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_system_token_routes_matrix_unit(monkeypatch):
    """Walk new_token / token_list / rm through tenant-missing, success and error branches.

    Each sub-scenario re-patches the service stubs; ordering matters because
    later patches override earlier ones on the same attribute.
    """
    module = _load_system_module(monkeypatch)

    # new_token: caller has no tenant membership -> data-error message.
    monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [])
    res = module.new_token()
    assert res["message"] == "Tenant not found!"

    # new_token: tenant found but persisting the token fails.
    monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [SimpleNamespace(role="owner", tenant_id="tenant-1")])
    monkeypatch.setattr(module.APITokenService, "save", lambda **_kwargs: False)
    res = module.new_token()
    assert res["message"] == "Fail to new a dialog!"

    # new_token: tenant lookup raises -> server_error_response (code 100).
    monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: (_ for _ in ()).throw(RuntimeError("tenant query boom")))
    res = module.new_token()
    assert res["code"] == 100
    assert "tenant query boom" in res["message"]

    # token_list: no tenant membership -> data-error message.
    monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [])
    res = module.token_list()
    assert res["message"] == "Tenant not found!"

    class _Token:
        # Minimal APIToken row stand-in; an empty beta should trigger the
        # route's lazy beta backfill.
        def __init__(self, token, beta):
            self.token = token
            self.beta = beta

        def to_dict(self):
            return {"token": self.token, "beta": self.beta}

    filter_updates = []
    monkeypatch.setattr(module, "generate_confirmation_token", lambda: "ragflow-abcdefghijklmnopqrstuvwxyz0123456789")
    monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [SimpleNamespace(role="owner", tenant_id="tenant-9")])
    monkeypatch.setattr(module.APITokenService, "query", lambda **_kwargs: [_Token("tok-1", ""), _Token("tok-2", "beta-2")])
    monkeypatch.setattr(module.APITokenService, "filter_update", lambda conds, payload: filter_updates.append((conds, payload)))
    res = module.token_list()
    assert res["code"] == 0
    assert len(res["data"]) == 2
    # The token that lacked a beta gets a generated 32-char one...
    assert len(res["data"][0]["beta"]) == 32
    # ...while an existing beta is kept as-is.
    assert res["data"][1]["beta"] == "beta-2"
    # Exactly one backfill write is persisted (only the token missing a beta).
    assert len(filter_updates) == 1

    # token_list: APIToken query raises -> server error path.
    monkeypatch.setattr(
        module.APITokenService,
        "query",
        lambda **_kwargs: (_ for _ in ()).throw(RuntimeError("token list boom")),
    )
    res = module.token_list()
    assert res["code"] == 100
    assert "token list boom" in res["message"]

    # rm: no tenant membership -> data-error message.
    monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [])
    res = module.rm("tok-1")
    assert res["message"] == "Tenant not found!"

    # rm: success path deletes the token and returns True.
    deleted = []
    monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [SimpleNamespace(role="owner", tenant_id="tenant-3")])
    monkeypatch.setattr(module.APITokenService, "filter_delete", lambda conds: deleted.append(conds))
    res = module.rm("tok-1")
    assert res["code"] == 0
    assert res["data"] is True
    assert deleted

    # rm: delete raises -> server error path.
    monkeypatch.setattr(
        module.APITokenService,
        "filter_delete",
        lambda _conds: (_ for _ in ()).throw(RuntimeError("delete boom")),
    )
    res = module.rm("tok-1")
    assert res["code"] == 100
    assert "delete boom" in res["message"]
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_get_config_returns_register_enabled_unit(monkeypatch):
    """get_config should mirror settings.REGISTER_ENABLED as `registerEnabled`."""
    system_module = _load_system_module(monkeypatch)
    monkeypatch.setattr(system_module.settings, "REGISTER_ENABLED", False)

    response = system_module.get_config()

    assert response["code"] == 0
    assert response["data"]["registerEnabled"] is False
|
||||
@ -0,0 +1,318 @@
|
||||
#
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import importlib.util
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from types import ModuleType, SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class _DummyManager:
|
||||
def route(self, *_args, **_kwargs):
|
||||
def decorator(func):
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class _AwaitableValue:
|
||||
def __init__(self, value):
|
||||
self._value = value
|
||||
|
||||
def __await__(self):
|
||||
async def _co():
|
||||
return self._value
|
||||
|
||||
return _co().__await__()
|
||||
|
||||
|
||||
class _ExprField:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
def __eq__(self, other):
|
||||
return (self.name, other)
|
||||
|
||||
|
||||
class _Invitee:
|
||||
def __init__(self, user_id="invitee-1", email="invitee@example.com"):
|
||||
self.id = user_id
|
||||
self.email = email
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"id": self.id,
|
||||
"avatar": "avatar-url",
|
||||
"email": self.email,
|
||||
"nickname": "Invitee",
|
||||
"password": "ignored",
|
||||
}
|
||||
|
||||
|
||||
def _run(coro):
|
||||
return asyncio.run(coro)
|
||||
|
||||
|
||||
def _set_request_json(monkeypatch, module, payload):
    """Patch *module*'s ``get_request_json`` so awaiting it yields *payload*."""

    def _fake_get_request_json():
        return _AwaitableValue(payload)

    monkeypatch.setattr(module, "get_request_json", _fake_get_request_json)
|
||||
|
||||
|
||||
def _load_tenant_module(monkeypatch):
    """Load ``api/apps/tenant_app.py`` in isolation with all imports stubbed.

    Every module tenant_app imports is replaced in ``sys.modules`` *before*
    ``exec_module`` runs, so the real app code executes against these stubs.
    The stub registration order mirrors the import chain; keep it intact.
    Returns the freshly executed module object.
    """
    # NOTE(review): parents[4] assumes this file sits four directories below
    # the repo root — confirm if the test tree is reorganized.
    repo_root = Path(__file__).resolve().parents[4]

    # Parent "api" package with a real __path__ so submodule imports resolve.
    api_pkg = ModuleType("api")
    api_pkg.__path__ = [str(repo_root / "api")]
    monkeypatch.setitem(sys.modules, "api", api_pkg)

    # api.apps: provides current_user (mutated by tests) and a pass-through
    # login_required decorator.
    apps_mod = ModuleType("api.apps")
    apps_mod.__path__ = [str(repo_root / "api" / "apps")]
    apps_mod.current_user = SimpleNamespace(id="tenant-1", email="owner@example.com")
    apps_mod.login_required = lambda fn: fn
    monkeypatch.setitem(sys.modules, "api.apps", apps_mod)

    # api.db: role constants used by the invite flow.
    db_mod = ModuleType("api.db")
    db_mod.UserTenantRole = SimpleNamespace(NORMAL="normal", OWNER="owner", INVITE="invite")
    monkeypatch.setitem(sys.modules, "api.db", db_mod)

    # api.db.db_models: UserTenant with _ExprField columns so the route's
    # ORM-style "column == value" filters become inspectable tuples.
    db_models_mod = ModuleType("api.db.db_models")
    db_models_mod.UserTenant = type(
        "UserTenant",
        (),
        {
            "tenant_id": _ExprField("tenant_id"),
            "user_id": _ExprField("user_id"),
        },
    )
    monkeypatch.setitem(sys.modules, "api.db.db_models", db_models_mod)

    services_pkg = ModuleType("api.db.services")
    services_pkg.__path__ = []
    monkeypatch.setitem(sys.modules, "api.db.services", services_pkg)

    user_service_mod = ModuleType("api.db.services.user_service")

    # Default service stubs: empty/neutral results; tests monkeypatch the
    # specific methods each scenario needs.
    class _UserTenantService:
        @staticmethod
        def get_by_tenant_id(_tenant_id):
            return []

        @staticmethod
        def query(**_kwargs):
            return []

        @staticmethod
        def save(**_kwargs):
            return True

        @staticmethod
        def filter_delete(_conditions):
            return True

        @staticmethod
        def get_tenants_by_user_id(_user_id):
            return []

        @staticmethod
        def filter_update(_conditions, _payload):
            return True

    class _UserService:
        @staticmethod
        def query(**_kwargs):
            return []

        @staticmethod
        def get_by_id(_user_id):
            # (found, row) convention; default is "not found".
            return False, None

    user_service_mod.UserTenantService = _UserTenantService
    user_service_mod.UserService = _UserService
    monkeypatch.setitem(sys.modules, "api.db.services.user_service", user_service_mod)

    # api.utils.api_utils: response helpers reduced to plain dicts so tests
    # can assert on code/message/data directly.
    api_utils_mod = ModuleType("api.utils.api_utils")
    api_utils_mod.get_json_result = lambda data=None, message="", code=0: {"code": code, "message": message, "data": data}
    api_utils_mod.get_data_error_result = lambda message="": {"code": 102, "message": message, "data": False}
    api_utils_mod.server_error_response = lambda exc: {"code": 100, "message": repr(exc), "data": False}
    api_utils_mod.validate_request = lambda *_args, **_kwargs: (lambda fn: fn)
    api_utils_mod.get_request_json = lambda: _AwaitableValue({})
    monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod)

    web_utils_mod = ModuleType("api.utils.web_utils")
    web_utils_mod.send_invite_email = lambda **_kwargs: {"ok": True}
    monkeypatch.setitem(sys.modules, "api.utils.web_utils", web_utils_mod)

    # "common" package and its submodules (constants, uuid, time helpers,
    # settings) with deterministic stand-ins.
    common_pkg = ModuleType("common")
    common_pkg.__path__ = [str(repo_root / "common")]
    monkeypatch.setitem(sys.modules, "common", common_pkg)

    constants_mod = ModuleType("common.constants")
    constants_mod.RetCode = SimpleNamespace(AUTHENTICATION_ERROR=401, SERVER_ERROR=500, DATA_ERROR=102)
    constants_mod.StatusEnum = SimpleNamespace(VALID=SimpleNamespace(value=1))
    monkeypatch.setitem(sys.modules, "common.constants", constants_mod)

    misc_utils_mod = ModuleType("common.misc_utils")
    misc_utils_mod.get_uuid = lambda: "uuid-1"
    monkeypatch.setitem(sys.modules, "common.misc_utils", misc_utils_mod)

    time_utils_mod = ModuleType("common.time_utils")
    time_utils_mod.delta_seconds = lambda _value: 0
    monkeypatch.setitem(sys.modules, "common.time_utils", time_utils_mod)

    settings_mod = ModuleType("common.settings")
    settings_mod.MAIL_FRONTEND_URL = "https://frontend.example/invite"
    monkeypatch.setitem(sys.modules, "common.settings", settings_mod)
    # Also expose as an attribute for "from common import settings"-style use.
    common_pkg.settings = settings_mod

    # Drop any stale copy from a previous test before executing a fresh one.
    sys.modules.pop("test_tenant_app_unit_module", None)
    module_path = repo_root / "api" / "apps" / "tenant_app.py"
    spec = importlib.util.spec_from_file_location("test_tenant_app_unit_module", module_path)
    module = importlib.util.module_from_spec(spec)
    # The route decorators need `manager` to exist at exec time.
    module.manager = _DummyManager()
    monkeypatch.setitem(sys.modules, "test_tenant_app_unit_module", module)
    spec.loader.exec_module(module)
    return module
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_user_list_auth_success_exception_matrix_unit(monkeypatch):
    """user_list: reject foreign tenants, enrich rows on success, wrap exceptions."""
    module = _load_tenant_module(monkeypatch)

    # Caller id differs from the tenant id -> authentication error.
    module.current_user.id = "other-user"
    res = module.user_list("tenant-1")
    assert res["code"] == module.RetCode.AUTHENTICATION_ERROR, res
    assert res["message"] == "No authorization.", res

    # Owner path: each returned row gains a delta_seconds field derived from
    # its update_date via the (patched) delta_seconds helper.
    module.current_user.id = "tenant-1"
    monkeypatch.setattr(
        module.UserTenantService,
        "get_by_tenant_id",
        lambda _tenant_id: [{"id": "u1", "update_date": "2024-01-01 00:00:00"}],
    )
    monkeypatch.setattr(module, "delta_seconds", lambda _value: 42)
    res = module.user_list("tenant-1")
    assert res["code"] == 0, res
    assert res["data"][0]["delta_seconds"] == 42, res

    # Service failure -> server_error_response (code 100) carrying the message.
    monkeypatch.setattr(module.UserTenantService, "get_by_tenant_id", lambda _tenant_id: (_ for _ in ()).throw(RuntimeError("list boom")))
    res = module.user_list("tenant-1")
    assert res["code"] == 100, res
    assert "list boom" in res["message"], res
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_create_invite_role_and_email_failure_matrix_unit(monkeypatch):
    """create (invite): auth, unknown user, role conflicts, success, email failure."""
    module = _load_tenant_module(monkeypatch)

    # Inviting on someone else's tenant -> authentication error.
    module.current_user.id = "other-user"
    _set_request_json(monkeypatch, module, {"email": "invitee@example.com"})
    res = _run(module.create("tenant-1"))
    assert res["code"] == module.RetCode.AUTHENTICATION_ERROR, res
    assert res["message"] == "No authorization.", res

    # Invitee email not registered -> data error.
    module.current_user.id = "tenant-1"
    monkeypatch.setattr(module.UserService, "query", lambda **_kwargs: [])
    res = _run(module.create("tenant-1"))
    assert res["message"] == "User not found.", res

    # Existing membership blocks re-invite: a normal member...
    invitee = _Invitee()
    monkeypatch.setattr(module.UserService, "query", lambda **_kwargs: [invitee])
    monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [SimpleNamespace(role=module.UserTenantRole.NORMAL)])
    res = _run(module.create("tenant-1"))
    assert "already in the team." in res["message"], res

    # ...the team owner...
    monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [SimpleNamespace(role=module.UserTenantRole.OWNER)])
    res = _run(module.create("tenant-1"))
    assert "owner of the team." in res["message"], res

    # ...or any unexpected role value.
    monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [SimpleNamespace(role="strange-role")])
    res = _run(module.create("tenant-1"))
    assert "role: strange-role is invalid." in res["message"], res

    # Success: membership saved with role INVITE, the invite email is
    # scheduled via asyncio.create_task, and the response exposes only safe
    # user fields (password is stripped).
    saved = []
    scheduled = []
    monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [])
    monkeypatch.setattr(module.UserTenantService, "save", lambda **kwargs: saved.append(kwargs) or True)
    monkeypatch.setattr(module.UserService, "get_by_id", lambda _user_id: (True, SimpleNamespace(nickname="Inviter Nick")))
    # send_invite_email returns its kwargs so the scheduled "task" payload can
    # be inspected directly.
    monkeypatch.setattr(module, "send_invite_email", lambda **kwargs: kwargs)
    monkeypatch.setattr(module.asyncio, "create_task", lambda payload: scheduled.append(payload) or SimpleNamespace())
    res = _run(module.create("tenant-1"))
    assert res["code"] == 0, res
    assert saved and saved[-1]["role"] == module.UserTenantRole.INVITE, saved
    assert scheduled and scheduled[-1]["inviter"] == "Inviter Nick", scheduled
    assert sorted(res["data"].keys()) == ["avatar", "email", "id", "nickname"], res

    # Scheduling the email raises -> server error with a friendly message.
    monkeypatch.setattr(module.asyncio, "create_task", lambda _payload: (_ for _ in ()).throw(RuntimeError("send boom")))
    res = _run(module.create("tenant-1"))
    assert res["code"] == module.RetCode.SERVER_ERROR, res
    assert "Failed to send invite email." in res["message"], res
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_rm_and_tenant_list_matrix_unit(monkeypatch):
    """rm: auth/success/error branches; tenant_list: enrichment and error wrapping."""
    module = _load_tenant_module(monkeypatch)

    # Caller is neither the tenant owner nor the removed user -> auth error.
    module.current_user.id = "outsider"
    res = module.rm("tenant-1", "user-2")
    assert res["code"] == module.RetCode.AUTHENTICATION_ERROR, res
    assert res["message"] == "No authorization.", res

    # Tenant owner removes a member -> membership deleted, True returned.
    module.current_user.id = "tenant-1"
    deleted = []
    monkeypatch.setattr(module.UserTenantService, "filter_delete", lambda conditions: deleted.append(conditions) or True)
    res = module.rm("tenant-1", "user-2")
    assert res["code"] == 0, res
    assert res["data"] is True, res
    assert deleted, "filter_delete should be called"

    # Delete failure -> server_error_response (code 100).
    monkeypatch.setattr(module.UserTenantService, "filter_delete", lambda _conditions: (_ for _ in ()).throw(RuntimeError("rm boom")))
    res = module.rm("tenant-1", "user-2")
    assert res["code"] == 100, res
    assert "rm boom" in res["message"], res

    # tenant_list success: rows gain delta_seconds from update_date.
    monkeypatch.setattr(
        module.UserTenantService,
        "get_tenants_by_user_id",
        lambda _user_id: [{"id": "tenant-1", "update_date": "2024-01-01 00:00:00"}],
    )
    monkeypatch.setattr(module, "delta_seconds", lambda _value: 9)
    res = module.tenant_list()
    assert res["code"] == 0, res
    assert res["data"][0]["delta_seconds"] == 9, res

    # tenant_list failure -> server error with the message surfaced.
    monkeypatch.setattr(module.UserTenantService, "get_tenants_by_user_id", lambda _user_id: (_ for _ in ()).throw(RuntimeError("tenant boom")))
    res = module.tenant_list()
    assert res["code"] == 100, res
    assert "tenant boom" in res["message"], res
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_agree_success_and_exception_unit(monkeypatch):
    """agree: promote the invite to NORMAL on success; wrap service failures."""
    module = _load_tenant_module(monkeypatch)

    recorded = []

    def _record_update(conditions, payload):
        recorded.append((conditions, payload))
        return True

    # Success path: the membership role is updated to NORMAL.
    monkeypatch.setattr(module.UserTenantService, "filter_update", _record_update)
    result = module.agree("tenant-1")
    assert result["code"] == 0, result
    assert result["data"] is True, result
    assert recorded and recorded[-1][1]["role"] == module.UserTenantRole.NORMAL

    def _boom_update(_conditions, _payload):
        raise RuntimeError("agree boom")

    # Failure path: the exception is wrapped by server_error_response.
    monkeypatch.setattr(module.UserTenantService, "filter_update", _boom_update)
    result = module.agree("tenant-1")
    assert result["code"] == 100, result
    assert "agree boom" in result["message"], result
|
||||
1324
test/testcases/test_web_api/test_user_app/test_user_app_unit.py
Normal file
1324
test/testcases/test_web_api/test_user_app/test_user_app_unit.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user