tests: improve RAGFlow coverage based on Codecov report (#13200)

### What problem does this PR solve?

Codecov’s coverage report shows that several RAGFlow code paths are
currently untested or under-tested. This makes it easier for regressions
to slip in during refactors and feature work.
This PR adds targeted automated tests to cover the files and branches
highlighted by Codecov, improving confidence in core behavior while
keeping runtime functionality unchanged.

### Type of change

- [x] Other (please describe): Test coverage improvement (adds/extends
unit and integration tests to address Codecov-reported gaps)
This commit is contained in:
6ba3i
2026-02-25 19:12:11 +08:00
committed by GitHub
parent 2a5ddf064d
commit 38011f2c16
56 changed files with 11453 additions and 17 deletions

View File

@@ -292,7 +292,7 @@ def knowledge_graph(auth, dataset_id, params=None, *, headers=HEADERS):
def delete_knowledge_graph(auth, dataset_id, payload=None, *, headers=HEADERS, data=None):
res = requests.delete(url=f"{HOST_ADDRESS}{KB_APP_URL}/{dataset_id}/delete_knowledge_graph", headers=headers, auth=auth, json=payload, data=data)
res = requests.delete(url=f"{HOST_ADDRESS}{KB_APP_URL}/{dataset_id}/knowledge_graph", headers=headers, auth=auth, json=payload, data=data)
return res.json()
@@ -434,6 +434,11 @@ def update_chunk(auth, payload=None, *, headers=HEADERS):
return res.json()
def switch_chunks(auth, payload=None, *, headers=HEADERS):
res = requests.post(url=f"{HOST_ADDRESS}{CHUNK_API_URL}/switch", headers=headers, auth=auth, json=payload)
return res.json()
def delete_chunks(auth, payload=None, *, headers=HEADERS):
res = requests.post(url=f"{HOST_ADDRESS}{CHUNK_API_URL}/rm", headers=headers, auth=auth, json=payload)
return res.json()

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,247 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import importlib.util
import sys
from pathlib import Path
from types import ModuleType, SimpleNamespace
import pytest
class _DummyManager:
def route(self, *_args, **_kwargs):
def decorator(func):
return func
return decorator
class _ExprField:
def __init__(self, name):
self.name = name
def __eq__(self, other):
return (self.name, other)
class _DummyAPITokenModel:
    # Stand-in for the APIToken DB model: only the two columns the routes
    # filter on, each yielding a (column, value) pair when compared.
    tenant_id = _ExprField("tenant_id")
    token = _ExprField("token")
def _run(coro):
return asyncio.run(coro)
def _load_api_app(monkeypatch):
    """Exec ``api/apps/api_app.py`` from source with every import stubbed.

    Lightweight stand-in modules are registered in ``sys.modules`` (via
    ``monkeypatch`` so they are removed after the test) before the module is
    executed, so it never touches the real web framework, services or DB.
    Returns the freshly executed module object.
    """
    repo_root = Path(__file__).resolve().parents[4]
    # Minimal quart: the module only reads `request`.
    quart_mod = ModuleType("quart")
    quart_mod.request = SimpleNamespace(args={})
    monkeypatch.setitem(sys.modules, "quart", quart_mod)
    apps_mod = ModuleType("api.apps")
    apps_mod.__path__ = [str(repo_root / "api" / "apps")]
    apps_mod.login_required = lambda fn: fn  # no-op auth decorator
    apps_mod.current_user = SimpleNamespace(id="user-1")
    monkeypatch.setitem(sys.modules, "api.apps", apps_mod)
    api_utils_mod = ModuleType("api.utils.api_utils")

    async def _get_request_json():
        return {}

    api_utils_mod.generate_confirmation_token = lambda: "token-123"
    api_utils_mod.get_request_json = _get_request_json
    # Response helpers reproduce the real envelopes as plain dicts.
    api_utils_mod.get_json_result = lambda data=None, message="", code=0: {
        "code": code,
        "message": message,
        "data": data,
    }
    api_utils_mod.get_data_error_result = lambda message="", code=400, data=None: {
        "code": code,
        "message": message,
        "data": data,
    }
    api_utils_mod.server_error_response = lambda exc: {
        "code": 500,
        "message": str(exc),
        "data": None,
    }
    api_utils_mod.validate_request = lambda *_args, **_kwargs: (lambda fn: fn)
    monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod)
    api_service_mod = ModuleType("api.db.services.api_service")

    class _StubAPITokenService:
        @staticmethod
        def save(**_kwargs):
            return True

        @staticmethod
        def query(**_kwargs):
            return []

        @staticmethod
        def filter_delete(_conds):
            return True

    class _StubAPI4ConversationService:
        @staticmethod
        def stats(*_args, **_kwargs):
            return []

    api_service_mod.APITokenService = _StubAPITokenService
    api_service_mod.API4ConversationService = _StubAPI4ConversationService
    monkeypatch.setitem(sys.modules, "api.db.services.api_service", api_service_mod)
    user_service_mod = ModuleType("api.db.services.user_service")

    class _StubUserTenantService:
        @staticmethod
        def query(**_kwargs):
            return [SimpleNamespace(tenant_id="tenant-1")]

    user_service_mod.UserTenantService = _StubUserTenantService
    monkeypatch.setitem(sys.modules, "api.db.services.user_service", user_service_mod)
    db_models_mod = ModuleType("api.db.db_models")
    db_models_mod.APIToken = _DummyAPITokenModel
    monkeypatch.setitem(sys.modules, "api.db.db_models", db_models_mod)
    time_utils_mod = ModuleType("common.time_utils")
    time_utils_mod.current_timestamp = lambda: 123
    time_utils_mod.datetime_format = lambda _dt: "2026-01-01 00:00:00"
    monkeypatch.setitem(sys.modules, "common.time_utils", time_utils_mod)
    module_path = repo_root / "api" / "apps" / "api_app.py"
    spec = importlib.util.spec_from_file_location("test_api_tokens_unit_module", module_path)
    module = importlib.util.module_from_spec(spec)
    # `manager` must exist before exec: the route decorators run at import time.
    module.manager = _DummyManager()
    spec.loader.exec_module(module)
    return module
@pytest.mark.p2
def test_new_token_branches_and_error_paths(monkeypatch):
    # Covers new_token(): tenant-missing guard, agent-token success,
    # save failure, and the exception -> 500 envelope path.
    module = _load_api_app(monkeypatch)

    async def req_canvas():
        return {"canvas_id": "canvas-1"}

    monkeypatch.setattr(module, "get_request_json", req_canvas)
    # No tenant rows -> early "Tenant not found!" return.
    monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [])
    res = _run(module.new_token())
    assert res["message"] == "Tenant not found!"
    monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [SimpleNamespace(tenant_id="tenant-1")])
    monkeypatch.setattr(module.APITokenService, "save", lambda **_kwargs: True)
    res = _run(module.new_token())
    assert res["code"] == 0
    assert res["data"]["tenant_id"] == "tenant-1"
    assert res["data"]["dialog_id"] == "canvas-1"
    assert res["data"]["source"] == "agent"
    monkeypatch.setattr(module.APITokenService, "save", lambda **_kwargs: False)
    res = _run(module.new_token())
    assert res["message"] == "Fail to new a dialog!"
    # Service-layer exceptions surface through server_error_response.
    monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: (_ for _ in ()).throw(RuntimeError("query failed")))
    res = _run(module.new_token())
    assert res["code"] == 500
    assert "query failed" in res["message"]
@pytest.mark.p2
def test_token_list_tenant_guard_and_exception(monkeypatch):
    # token_list(): tenant guard first, then a request missing canvas_id
    # drives the handler into its exception envelope.
    module = _load_api_app(monkeypatch)
    monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [])
    monkeypatch.setattr(module, "request", SimpleNamespace(args={"dialog_id": "d1"}))
    res = module.token_list()
    assert res["message"] == "Tenant not found!"
    monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [SimpleNamespace(tenant_id="tenant-1")])
    monkeypatch.setattr(module, "request", SimpleNamespace(args={}))
    res = module.token_list()
    assert res["code"] == 500
    assert "canvas_id" in res["message"]
@pytest.mark.p2
def test_rm_exception_path(monkeypatch):
    # rm(): a failing filter_delete must surface as the 500 envelope.
    module = _load_api_app(monkeypatch)

    async def req_rm():
        return {"tokens": ["tok-1"], "tenant_id": "tenant-1"}

    monkeypatch.setattr(module, "get_request_json", req_rm)
    monkeypatch.setattr(
        module.APITokenService,
        "filter_delete",
        lambda *_args, **_kwargs: (_ for _ in ()).throw(RuntimeError("delete failed")),
    )
    res = _run(module.rm())
    assert res["code"] == 500
    assert "delete failed" in res["message"]
@pytest.mark.p2
def test_stats_aggregation_and_error_paths(monkeypatch):
    # stats(): tenant guard, per-day series aggregation, and the 500 path.
    module = _load_api_app(monkeypatch)
    monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [])
    monkeypatch.setattr(module, "request", SimpleNamespace(args={}))
    res = module.stats()
    assert res["message"] == "Tenant not found!"
    monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [SimpleNamespace(tenant_id="tenant-1")])
    monkeypatch.setattr(module, "request", SimpleNamespace(args={"canvas_id": "canvas-1"}))
    monkeypatch.setattr(
        module.API4ConversationService,
        "stats",
        lambda *_args, **_kwargs: [
            {
                "dt": "2026-01-01",
                "pv": 3,
                "uv": 2,
                "tokens": 100,
                "duration": 9.9,
                "round": 1,
                "thumb_up": 0,
            }
        ],
    )
    res = module.stats()
    assert res["code"] == 0
    assert res["data"]["pv"] == [("2026-01-01", 3)]
    assert res["data"]["uv"] == [("2026-01-01", 2)]
    assert res["data"]["round"] == [("2026-01-01", 1)]
    assert res["data"]["thumb_up"] == [("2026-01-01", 0)]
    # tokens series is scaled (100 -> 0.1); speed is derived inside the app
    # (presumably tokens/duration rounded) — values pinned from observed output.
    assert res["data"]["tokens"] == [("2026-01-01", 0.1)]
    assert res["data"]["speed"] == [("2026-01-01", 10.0)]
    monkeypatch.setattr(
        module.API4ConversationService,
        "stats",
        lambda *_args, **_kwargs: (_ for _ in ()).throw(RuntimeError("stats failed")),
    )
    res = module.stats()
    assert res["code"] == 500
    assert "stats failed" in res["message"]

View File

@@ -0,0 +1,484 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import importlib.util
import sys
from pathlib import Path
from types import ModuleType, SimpleNamespace
import pytest
class _FakeResponse:
def __init__(self, payload=None, err=None):
self._payload = payload or {}
self._err = err
def raise_for_status(self):
if self._err:
raise self._err
def json(self):
return self._payload
class _DummyJwkClient:
def __init__(self, _jwks_uri):
self._key = "dummy-signing-key"
def get_signing_key_from_jwt(self, _id_token):
return SimpleNamespace(key=self._key)
def _load_auth_modules(monkeypatch):
    """Load ``api/apps/auth/{oauth,oidc}.py`` from source with stub packages.

    Parent packages are registered in sys.modules first so the modules'
    relative imports resolve without importing the real application.
    Returns ``(oauth_module, oidc_module)``.
    """
    repo_root = Path(__file__).resolve().parents[4]
    common_pkg = ModuleType("common")
    common_pkg.__path__ = [str(repo_root / "common")]
    monkeypatch.setitem(sys.modules, "common", common_pkg)
    api_pkg = ModuleType("api")
    api_pkg.__path__ = [str(repo_root / "api")]
    apps_pkg = ModuleType("api.apps")
    apps_pkg.__path__ = [str(repo_root / "api" / "apps")]
    auth_pkg = ModuleType("api.apps.auth")
    auth_pkg.__path__ = [str(repo_root / "api" / "apps" / "auth")]
    monkeypatch.setitem(sys.modules, "api", api_pkg)
    monkeypatch.setitem(sys.modules, "api.apps", apps_pkg)
    monkeypatch.setitem(sys.modules, "api.apps.auth", auth_pkg)
    # Drop any previously loaded copies so each test gets a fresh import.
    for mod_name in ["api.apps.auth.oauth", "api.apps.auth.oidc"]:
        sys.modules.pop(mod_name, None)
    oauth_path = repo_root / "api" / "apps" / "auth" / "oauth.py"
    oauth_spec = importlib.util.spec_from_file_location("api.apps.auth.oauth", oauth_path)
    oauth_module = importlib.util.module_from_spec(oauth_spec)
    # Register before exec so oidc.py's import of the oauth module resolves.
    monkeypatch.setitem(sys.modules, "api.apps.auth.oauth", oauth_module)
    oauth_spec.loader.exec_module(oauth_module)
    oidc_path = repo_root / "api" / "apps" / "auth" / "oidc.py"
    oidc_spec = importlib.util.spec_from_file_location("api.apps.auth.oidc", oidc_path)
    oidc_module = importlib.util.module_from_spec(oidc_spec)
    monkeypatch.setitem(sys.modules, "api.apps.auth.oidc", oidc_module)
    oidc_spec.loader.exec_module(oidc_module)
    return oauth_module, oidc_module
def _load_github_module(monkeypatch):
    """Load ``api/apps/auth/github.py`` on top of the stubbed auth packages."""
    _load_auth_modules(monkeypatch)
    repo_root = Path(__file__).resolve().parents[4]
    # Force a fresh import of the github module for this test.
    sys.modules.pop("api.apps.auth.github", None)
    github_path = repo_root / "api" / "apps" / "auth" / "github.py"
    github_spec = importlib.util.spec_from_file_location("api.apps.auth.github", github_path)
    github_module = importlib.util.module_from_spec(github_spec)
    monkeypatch.setitem(sys.modules, "api.apps.auth.github", github_module)
    github_spec.loader.exec_module(github_module)
    return github_module
def _load_auth_init_module(monkeypatch):
    """Load ``api/apps/auth/__init__.py`` with a stubbed github submodule."""
    _load_auth_modules(monkeypatch)
    repo_root = Path(__file__).resolve().parents[4]
    # Replace the github submodule so the package import does not pull in
    # the real GithubOAuthClient.
    github_mod = ModuleType("api.apps.auth.github")

    class _StubGithubOAuthClient:
        def __init__(self, config):
            self.config = config

    github_mod.GithubOAuthClient = _StubGithubOAuthClient
    monkeypatch.setitem(sys.modules, "api.apps.auth.github", github_mod)
    init_path = repo_root / "api" / "apps" / "auth" / "__init__.py"
    init_spec = importlib.util.spec_from_file_location(
        "api.apps.auth",
        init_path,
        submodule_search_locations=[str(repo_root / "api" / "apps" / "auth")],
    )
    init_module = importlib.util.module_from_spec(init_spec)
    monkeypatch.setitem(sys.modules, "api.apps.auth", init_module)
    init_spec.loader.exec_module(init_module)
    return init_module
def _base_config():
return {
"issuer": "https://issuer.example",
"client_id": "client-1",
"client_secret": "secret-1",
"redirect_uri": "https://app.example/callback",
}
def _metadata(issuer):
return {
"issuer": issuer,
"jwks_uri": f"{issuer}/jwks",
"authorization_endpoint": f"{issuer}/authorize",
"token_endpoint": f"{issuer}/token",
"userinfo_endpoint": f"{issuer}/userinfo",
}
def _make_client(monkeypatch, oidc_module):
    # Build an OIDCClient without network access: discovery is replaced by
    # the canned _metadata() document.
    monkeypatch.setattr(oidc_module.OIDCClient, "_load_oidc_metadata", staticmethod(lambda issuer: _metadata(issuer)))
    return oidc_module.OIDCClient(_base_config())
@pytest.mark.p2
def test_oidc_init_requires_issuer(monkeypatch):
    # OIDCClient must refuse a configuration that has no issuer.
    _, oidc_module = _load_auth_modules(monkeypatch)
    with pytest.raises(ValueError) as exc_info:
        oidc_module.OIDCClient({"client_id": "cid"})
    assert str(exc_info.value) == "Missing issuer in configuration."
@pytest.mark.p2
def test_oidc_init_loads_metadata_and_sets_endpoints(monkeypatch):
    # Endpoint URLs must be taken from the discovery metadata document.
    _, oidc_module = _load_auth_modules(monkeypatch)
    monkeypatch.setattr(oidc_module.OIDCClient, "_load_oidc_metadata", staticmethod(lambda issuer: _metadata(issuer)))
    client = oidc_module.OIDCClient(_base_config())
    assert client.issuer == "https://issuer.example"
    assert client.jwks_uri == "https://issuer.example/jwks"
    assert client.authorization_url == "https://issuer.example/authorize"
    assert client.token_url == "https://issuer.example/token"
    assert client.userinfo_url == "https://issuer.example/userinfo"
@pytest.mark.p2
def test_load_oidc_metadata_success_and_wraps_failure(monkeypatch):
    # _load_oidc_metadata fetches the well-known URL; failures are wrapped
    # in ValueError.
    _, oidc_module = _load_auth_modules(monkeypatch)
    calls = {}

    def _ok_sync_request(method, url, timeout):
        calls.update({"method": method, "url": url, "timeout": timeout})
        return _FakeResponse(_metadata("https://issuer.example"))

    monkeypatch.setattr(oidc_module, "sync_request", _ok_sync_request)
    metadata = oidc_module.OIDCClient._load_oidc_metadata("https://issuer.example")
    assert metadata["jwks_uri"] == "https://issuer.example/jwks"
    assert calls == {
        "method": "GET",
        "url": "https://issuer.example/.well-known/openid-configuration",
        "timeout": 7,
    }

    def _boom_sync_request(*_args, **_kwargs):
        raise RuntimeError("metadata boom")

    monkeypatch.setattr(oidc_module, "sync_request", _boom_sync_request)
    with pytest.raises(ValueError) as exc_info:
        oidc_module.OIDCClient._load_oidc_metadata("https://issuer.example")
    assert str(exc_info.value) == "Failed to fetch OIDC metadata: metadata boom"
@pytest.mark.p2
def test_parse_id_token_success_and_error(monkeypatch):
    # parse_id_token: checks the JWKS lookup + jwt.decode arguments on the
    # success path, then the ValueError wrapping of decode failures.
    _, oidc_module = _load_auth_modules(monkeypatch)
    client = _make_client(monkeypatch, oidc_module)
    monkeypatch.setattr(oidc_module.jwt, "get_unverified_header", lambda _token: {})
    seen = {}

    class _JwkClient(_DummyJwkClient):
        def __init__(self, jwks_uri):
            super().__init__(jwks_uri)
            seen["jwks_uri"] = jwks_uri

        def get_signing_key_from_jwt(self, id_token):
            seen["id_token"] = id_token
            return super().get_signing_key_from_jwt(id_token)

    monkeypatch.setattr(oidc_module.jwt, "PyJWKClient", _JwkClient)

    def _decode(id_token, key, algorithms, audience, issuer):
        seen.update(
            {
                "decode_id_token": id_token,
                "decode_key": key,
                "algorithms": algorithms,
                "audience": audience,
                "issuer": issuer,
            }
        )
        return {"sub": "user-1", "email": "id@example.com"}

    monkeypatch.setattr(oidc_module.jwt, "decode", _decode)
    parsed = client.parse_id_token("id-token-1")
    assert parsed["sub"] == "user-1"
    assert seen["jwks_uri"] == "https://issuer.example/jwks"
    assert seen["decode_key"] == "dummy-signing-key"
    assert seen["algorithms"] == ["RS256"]
    assert seen["audience"] == "client-1"
    assert seen["issuer"] == "https://issuer.example"

    def _raise_decode(*_args, **_kwargs):
        raise RuntimeError("decode boom")

    monkeypatch.setattr(oidc_module.jwt, "decode", _raise_decode)
    with pytest.raises(ValueError) as exc_info:
        client.parse_id_token("id-token-2")
    assert str(exc_info.value) == "Error parsing ID Token: decode boom"
@pytest.mark.p2
def test_fetch_user_info_merges_id_token_and_oauth_userinfo(monkeypatch):
    # OIDC fetch_user_info: profile fields come from the parent OAuth call,
    # while the missing avatar is filled from the ID token's picture claim.
    oauth_module, oidc_module = _load_auth_modules(monkeypatch)
    client = _make_client(monkeypatch, oidc_module)
    monkeypatch.setattr(
        oidc_module.OIDCClient,
        "parse_id_token",
        lambda self, _id_token: {"picture": "id-picture", "email": "id@example.com"},
    )

    def _fake_parent_fetch(self, access_token, **_kwargs):
        assert access_token == "access-1"
        return oauth_module.UserInfo(
            email="oauth@example.com",
            username="oauth-user",
            nickname="oauth-nick",
            avatar_url=None,
        )

    monkeypatch.setattr(oauth_module.OAuthClient, "fetch_user_info", _fake_parent_fetch)
    info = client.fetch_user_info("access-1", id_token="id-token")
    assert info.email == "oauth@example.com"
    assert info.username == "oauth-user"
    assert info.nickname == "oauth-nick"
    assert info.avatar_url == "id-picture"
@pytest.mark.p2
def test_async_fetch_user_info_merges_id_token_and_oauth_userinfo(monkeypatch):
    # Async variant of the merge test above.
    oauth_module, oidc_module = _load_auth_modules(monkeypatch)
    client = _make_client(monkeypatch, oidc_module)
    monkeypatch.setattr(
        oidc_module.OIDCClient,
        "parse_id_token",
        lambda self, _id_token: {"picture": "id-picture-async", "email": "id-async@example.com"},
    )

    async def _fake_parent_async_fetch(self, access_token, **_kwargs):
        assert access_token == "access-2"
        return oauth_module.UserInfo(
            email="oauth-async@example.com",
            username="oauth-async-user",
            nickname="oauth-async-nick",
            avatar_url=None,
        )

    monkeypatch.setattr(oauth_module.OAuthClient, "async_fetch_user_info", _fake_parent_async_fetch)
    info = asyncio.run(client.async_fetch_user_info("access-2", id_token="id-token"))
    assert info.email == "oauth-async@example.com"
    assert info.username == "oauth-async-user"
    assert info.nickname == "oauth-async-nick"
    assert info.avatar_url == "id-picture-async"
@pytest.mark.p2
def test_normalize_user_info_passthrough(monkeypatch):
    # normalize_user_info maps OIDC claim names onto UserInfo fields
    # (picture -> avatar_url).
    oauth_module, oidc_module = _load_auth_modules(monkeypatch)
    client = _make_client(monkeypatch, oidc_module)
    result = client.normalize_user_info(
        {
            "email": "user@example.com",
            "username": "user",
            "nickname": "User",
            "picture": "picture-url",
        }
    )
    assert isinstance(result, oauth_module.UserInfo)
    assert result.to_dict() == {
        "email": "user@example.com",
        "username": "user",
        "nickname": "User",
        "avatar_url": "picture-url",
    }
@pytest.mark.p2
def test_get_auth_client_type_inference_and_unsupported(monkeypatch):
    # get_auth_client: an issuer implies oidc, no hints default to oauth2,
    # and an unknown explicit type raises.
    auth_module = _load_auth_init_module(monkeypatch)

    class _FakeOAuth2Client:
        def __init__(self, config):
            self.config = config

    class _FakeOidcClient:
        def __init__(self, config):
            self.config = config

    class _FakeGithubClient:
        def __init__(self, config):
            self.config = config

    monkeypatch.setattr(
        auth_module,
        "CLIENT_TYPES",
        {
            "oauth2": _FakeOAuth2Client,
            "oidc": _FakeOidcClient,
            "github": _FakeGithubClient,
        },
    )
    oidc_client = auth_module.get_auth_client({"issuer": "https://issuer.example"})
    assert isinstance(oidc_client, _FakeOidcClient)
    oauth_client = auth_module.get_auth_client({})
    assert isinstance(oauth_client, _FakeOAuth2Client)
    with pytest.raises(ValueError, match="Unsupported type: invalid"):
        auth_module.get_auth_client({"type": "invalid"})
@pytest.mark.p2
def test_github_oauth_client_init_and_normalize_unit(monkeypatch):
    # GithubOAuthClient: fixed GitHub endpoints, login/name mapping, and the
    # email-prefix fallback when login/name are absent.
    github_module = _load_github_module(monkeypatch)
    client = github_module.GithubOAuthClient(_base_config())
    assert client.authorization_url == "https://github.com/login/oauth/authorize"
    assert client.token_url == "https://github.com/login/oauth/access_token"
    assert client.userinfo_url == "https://api.github.com/user"
    assert client.scope == "user:email"
    normalized = client.normalize_user_info(
        {
            "email": "octo@example.com",
            "login": "octocat",
            "name": "Octo Cat",
            "avatar_url": "https://avatar.example/octocat.png",
        }
    )
    assert normalized.to_dict() == {
        "email": "octo@example.com",
        "username": "octocat",
        "nickname": "Octo Cat",
        "avatar_url": "https://avatar.example/octocat.png",
    }
    normalized_fallback = client.normalize_user_info({"email": "fallback@example.com"})
    assert normalized_fallback.to_dict() == {
        "email": "fallback@example.com",
        "username": "fallback",
        "nickname": "fallback",
        "avatar_url": "",
    }
@pytest.mark.p2
def test_github_fetch_user_info_sync_success_and_error_unit(monkeypatch):
    # Sync fetch: /user then /user/emails (the primary email wins), bearer
    # auth on both calls, and ValueError wrapping on HTTP errors.
    github_module = _load_github_module(monkeypatch)
    client = github_module.GithubOAuthClient(_base_config())
    calls = []

    def _fake_sync_request(method, url, headers=None, timeout=None):
        calls.append((method, url, headers, timeout))
        if url.endswith("/emails"):
            return _FakeResponse(
                [
                    {"email": "other@example.com", "primary": False},
                    {"email": "octo@example.com", "primary": True},
                ]
            )
        return _FakeResponse({"login": "octocat", "name": "Octo Cat", "avatar_url": "https://avatar.example/octocat.png"})

    monkeypatch.setattr(github_module, "sync_request", _fake_sync_request)
    info = client.fetch_user_info("sync-token")
    assert info.to_dict() == {
        "email": "octo@example.com",
        "username": "octocat",
        "nickname": "Octo Cat",
        "avatar_url": "https://avatar.example/octocat.png",
    }
    assert [call[1] for call in calls] == [
        "https://api.github.com/user",
        "https://api.github.com/user/emails",
    ]
    assert all(call[2]["Authorization"] == "Bearer sync-token" for call in calls)
    assert all(call[3] == 7 for call in calls)

    def _sync_request_raises(*_args, **_kwargs):
        return _FakeResponse(err=RuntimeError("status boom"))

    monkeypatch.setattr(github_module, "sync_request", _sync_request_raises)
    with pytest.raises(ValueError, match="Failed to fetch github user info: status boom"):
        client.fetch_user_info("sync-token")
@pytest.mark.p2
def test_github_fetch_user_info_async_success_and_error_unit(monkeypatch):
    # Async variant of the sync fetch test above.
    github_module = _load_github_module(monkeypatch)
    client = github_module.GithubOAuthClient(_base_config())
    calls = []

    async def _fake_async_request(method, url, headers=None, **kwargs):
        calls.append((method, url, headers, kwargs.get("timeout")))
        if url.endswith("/emails"):
            return _FakeResponse(
                [
                    {"email": "other@example.com", "primary": False},
                    {"email": "octo-async@example.com", "primary": True},
                ]
            )
        return _FakeResponse(
            {"login": "octocat-async", "name": "Octo Async", "avatar_url": "https://avatar.example/octo-async.png"}
        )

    monkeypatch.setattr(github_module, "async_request", _fake_async_request)
    info = asyncio.run(client.async_fetch_user_info("async-token"))
    assert info.to_dict() == {
        "email": "octo-async@example.com",
        "username": "octocat-async",
        "nickname": "Octo Async",
        "avatar_url": "https://avatar.example/octo-async.png",
    }
    assert [call[1] for call in calls] == [
        "https://api.github.com/user",
        "https://api.github.com/user/emails",
    ]
    assert all(call[2]["Authorization"] == "Bearer async-token" for call in calls)
    assert all(call[3] == 7 for call in calls)

    async def _async_request_raises(*_args, **_kwargs):
        return _FakeResponse(err=RuntimeError("async status boom"))

    monkeypatch.setattr(github_module, "async_request", _async_request_raises)
    with pytest.raises(ValueError, match="Failed to fetch github user info: async status boom"):
        asyncio.run(client.async_fetch_user_info("async-token"))

View File

@@ -0,0 +1,669 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import base64
import importlib.util
import sys
from pathlib import Path
from types import ModuleType, SimpleNamespace
import pytest
class _DummyManager:
def route(self, *_args, **_kwargs):
def decorator(func):
return func
return decorator
class _AwaitableValue:
def __init__(self, value):
self._value = value
def __await__(self):
async def _co():
return self._value
return _co().__await__()
class _Vec(list):
def __mul__(self, scalar):
return _Vec([scalar * x for x in self])
__rmul__ = __mul__
def __add__(self, other):
return _Vec([a + b for a, b in zip(self, other)])
def tolist(self):
return list(self)
class _DummyDoc:
def __init__(self, *, doc_id="doc-1", kb_id="kb-1", name="Doc", parser_id="naive"):
self.id = doc_id
self.kb_id = kb_id
self.name = name
self.parser_id = parser_id
def to_dict(self):
return {"id": self.id, "kb_id": self.kb_id, "name": self.name}
class _DummyRetCode:
    # Mirrors the RetCode constants the chunk routes put in response envelopes.
    SUCCESS = 0
    DATA_ERROR = 102
    EXCEPTION_ERROR = 100
class _DummyParserType:
    # Only the two parser types the chunk routes branch on.
    QA = "qa"
    NAIVE = "naive"
class _DummyRetriever:
    """Retriever double: every search returns one canned chunk hit."""

    async def search(self, query, _index_name, _kb_ids, highlight=None):
        class _SRes:
            # Shape mirrors the fields the chunk routes read off a search result.
            total = 1
            ids = ["chunk-1"]
            field = {
                "chunk-1": {
                    "content_with_weight": "chunk content",
                    "doc_id": "doc-1",
                    "docnm_kwd": "Doc",
                    "important_kwd": ["k1"],
                    "question_kwd": ["q1"],
                    "img_id": "img-1",
                    "available_int": 1,
                    "position_int": [],
                    "doc_type_kwd": "text",
                }
            }
            # Padded on purpose: callers are expected to strip the highlight.
            highlight = {"chunk-1": " highlighted content "}

        _ = (query, highlight)  # accepted only for signature parity
        return _SRes()
class _DummyDocStore:
def __init__(self):
self.updated = []
self.inserted = []
self.deleted_inputs = []
self.to_delete = [1]
self.chunk = {
"id": "chunk-1",
"doc_id": "doc-1",
"kb_id": "kb-1",
"content_with_weight": "chunk content",
"docnm_kwd": "Doc",
"q_2_vec": [0.1, 0.2],
"content_tks": ["a"],
"content_ltks": ["b"],
"content_sm_ltks": ["c"],
}
def get(self, *_args, **_kwargs):
return dict(self.chunk) if self.chunk is not None else None
def update(self, condition, payload, *_args, **_kwargs):
self.updated.append((condition, payload))
return True
def delete(self, condition, *_args, **_kwargs):
self.deleted_inputs.append(condition)
if not self.to_delete:
return 0
return self.to_delete.pop(0)
def insert(self, docs, *_args, **_kwargs):
self.inserted.extend(docs)
class _DummyStorage:
def __init__(self):
self.put_calls = []
self.rm_calls = []
def put(self, bucket, name, binary):
self.put_calls.append((bucket, name, binary))
def obj_exist(self, _bucket, _name):
return True
def rm(self, bucket, name):
self.rm_calls.append((bucket, name))
class _DummyTenant:
def __init__(self, tenant_id="tenant-1"):
self.tenant_id = tenant_id
class _DummyLLMBundle:
    """Embedding-model double: two fixed vectors plus a token count of 9."""

    def __init__(self, *_args, **_kwargs):
        # Constructor arguments (tenant, model type, ...) are irrelevant here.
        pass

    def encode(self, _inputs):
        vectors = [_Vec([1.0, 2.0]), _Vec([3.0, 4.0])]
        return vectors, 9
class _DummyXXHash:
def __init__(self, data):
self._data = data
def hexdigest(self):
return f"chunk-{len(self._data)}"
def _run(coro):
return asyncio.run(coro)
def _load_chunk_module(monkeypatch):
repo_root = Path(__file__).resolve().parents[4]
quart_mod = ModuleType("quart")
quart_mod.request = SimpleNamespace(args={}, headers={})
monkeypatch.setitem(sys.modules, "quart", quart_mod)
xxhash_mod = ModuleType("xxhash")
xxhash_mod.xxh64 = lambda data: _DummyXXHash(data)
monkeypatch.setitem(sys.modules, "xxhash", xxhash_mod)
common_pkg = ModuleType("common")
common_pkg.__path__ = [str(repo_root / "common")]
monkeypatch.setitem(sys.modules, "common", common_pkg)
settings_mod = ModuleType("common.settings")
settings_mod.retriever = _DummyRetriever()
settings_mod.docStoreConn = _DummyDocStore()
settings_mod.STORAGE_IMPL = _DummyStorage()
monkeypatch.setitem(sys.modules, "common.settings", settings_mod)
common_pkg.settings = settings_mod
constants_mod = ModuleType("common.constants")
class _DummyLLMType:
EMBEDDING = SimpleNamespace(value="embedding")
CHAT = SimpleNamespace(value="chat")
constants_mod.RetCode = _DummyRetCode
constants_mod.LLMType = _DummyLLMType
constants_mod.ParserType = _DummyParserType
constants_mod.PAGERANK_FLD = "pagerank_flt"
monkeypatch.setitem(sys.modules, "common.constants", constants_mod)
string_utils_mod = ModuleType("common.string_utils")
string_utils_mod.remove_redundant_spaces = lambda text: " ".join(str(text).split())
monkeypatch.setitem(sys.modules, "common.string_utils", string_utils_mod)
metadata_utils_mod = ModuleType("common.metadata_utils")
metadata_utils_mod.apply_meta_data_filter = lambda *_args, **_kwargs: {}
monkeypatch.setitem(sys.modules, "common.metadata_utils", metadata_utils_mod)
misc_utils_mod = ModuleType("common.misc_utils")
async def _thread_pool_exec(func):
return func()
misc_utils_mod.thread_pool_exec = _thread_pool_exec
monkeypatch.setitem(sys.modules, "common.misc_utils", misc_utils_mod)
rag_pkg = ModuleType("rag")
rag_pkg.__path__ = []
monkeypatch.setitem(sys.modules, "rag", rag_pkg)
rag_app_pkg = ModuleType("rag.app")
rag_app_pkg.__path__ = []
monkeypatch.setitem(sys.modules, "rag.app", rag_app_pkg)
rag_qa_mod = ModuleType("rag.app.qa")
rag_qa_mod.rmPrefix = lambda text: str(text).strip("Q: ").strip("A: ")
rag_qa_mod.beAdoc = lambda d, q, a, _latin: {**d, "question_kwd": [q], "content_with_weight": f"{q}\n{a}"}
monkeypatch.setitem(sys.modules, "rag.app.qa", rag_qa_mod)
rag_tag_mod = ModuleType("rag.app.tag")
rag_tag_mod.label_question = lambda *_args, **_kwargs: []
monkeypatch.setitem(sys.modules, "rag.app.tag", rag_tag_mod)
rag_nlp_mod = ModuleType("rag.nlp")
rag_nlp_mod.rag_tokenizer = SimpleNamespace(
tokenize=lambda text: [str(text)],
fine_grained_tokenize=lambda toks: [f"fg:{t}" for t in toks],
is_chinese=lambda _text: False,
)
rag_nlp_mod.search = SimpleNamespace(index_name=lambda tenant_id: f"idx-{tenant_id}")
monkeypatch.setitem(sys.modules, "rag.nlp", rag_nlp_mod)
rag_prompts_pkg = ModuleType("rag.prompts")
rag_prompts_pkg.__path__ = []
monkeypatch.setitem(sys.modules, "rag.prompts", rag_prompts_pkg)
rag_generator_mod = ModuleType("rag.prompts.generator")
rag_generator_mod.cross_languages = lambda *_args, **_kwargs: []
rag_generator_mod.keyword_extraction = lambda *_args, **_kwargs: []
monkeypatch.setitem(sys.modules, "rag.prompts.generator", rag_generator_mod)
apps_mod = ModuleType("api.apps")
apps_mod.__path__ = [str(repo_root / "api" / "apps")]
apps_mod.current_user = SimpleNamespace(id="user-1")
apps_mod.login_required = lambda func: func
monkeypatch.setitem(sys.modules, "api.apps", apps_mod)
api_utils_mod = ModuleType("api.utils.api_utils")
api_utils_mod.get_json_result = lambda data=None, message="", code=0: {"code": code, "message": message, "data": data}
api_utils_mod.get_data_error_result = lambda message="": {"code": _DummyRetCode.DATA_ERROR, "message": message, "data": False}
api_utils_mod.server_error_response = lambda exc: {"code": _DummyRetCode.EXCEPTION_ERROR, "message": repr(exc), "data": False}
api_utils_mod.validate_request = lambda *_args, **_kwargs: (lambda fn: fn)
api_utils_mod.get_request_json = lambda: _AwaitableValue({})
monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod)
services_pkg = ModuleType("api.db.services")
services_pkg.__path__ = []
monkeypatch.setitem(sys.modules, "api.db.services", services_pkg)
document_service_mod = ModuleType("api.db.services.document_service")
class _DocumentService:
decrement_calls = []
increment_calls = []
@staticmethod
def get_tenant_id(_doc_id):
return "tenant-1"
@staticmethod
def get_by_id(doc_id):
return True, _DummyDoc(doc_id=doc_id, parser_id=_DummyParserType.NAIVE)
@staticmethod
def get_embd_id(_doc_id):
return "embed-1"
@staticmethod
def decrement_chunk_num(*args):
_DocumentService.decrement_calls.append(args)
@staticmethod
def increment_chunk_num(*args):
_DocumentService.increment_calls.append(args)
document_service_mod.DocumentService = _DocumentService
monkeypatch.setitem(sys.modules, "api.db.services.document_service", document_service_mod)
services_pkg.document_service = document_service_mod
doc_metadata_service_mod = ModuleType("api.db.services.doc_metadata_service")
doc_metadata_service_mod.DocMetadataService = type("DocMetadataService", (), {})
monkeypatch.setitem(sys.modules, "api.db.services.doc_metadata_service", doc_metadata_service_mod)
services_pkg.doc_metadata_service = doc_metadata_service_mod
kb_service_mod = ModuleType("api.db.services.knowledgebase_service")
class _KnowledgebaseService:
@staticmethod
def get_kb_ids(_tenant_id):
return ["kb-1"]
@staticmethod
def get_by_id(_kb_id):
return True, SimpleNamespace(pagerank=0.6)
kb_service_mod.KnowledgebaseService = _KnowledgebaseService
monkeypatch.setitem(sys.modules, "api.db.services.knowledgebase_service", kb_service_mod)
services_pkg.knowledgebase_service = kb_service_mod
llm_service_mod = ModuleType("api.db.services.llm_service")
llm_service_mod.LLMBundle = _DummyLLMBundle
monkeypatch.setitem(sys.modules, "api.db.services.llm_service", llm_service_mod)
services_pkg.llm_service = llm_service_mod
search_service_mod = ModuleType("api.db.services.search_service")
search_service_mod.SearchService = type("SearchService", (), {})
monkeypatch.setitem(sys.modules, "api.db.services.search_service", search_service_mod)
services_pkg.search_service = search_service_mod
user_service_mod = ModuleType("api.db.services.user_service")
class _UserTenantService:
@staticmethod
def query(**_kwargs):
return [_DummyTenant("tenant-1")]
user_service_mod.UserTenantService = _UserTenantService
monkeypatch.setitem(sys.modules, "api.db.services.user_service", user_service_mod)
services_pkg.user_service = user_service_mod
module_name = "test_chunk_routes_unit_module"
module_path = repo_root / "api" / "apps" / "chunk_app.py"
spec = importlib.util.spec_from_file_location(module_name, module_path)
module = importlib.util.module_from_spec(spec)
module.manager = _DummyManager()
monkeypatch.setitem(sys.modules, module_name, module)
spec.loader.exec_module(module)
return module
def _set_request_json(monkeypatch, module, payload):
    """Patch the loaded module so get_request_json() yields *payload*."""

    def _fake_get_request_json():
        return _AwaitableValue(payload)

    monkeypatch.setattr(module, "get_request_json", _fake_get_request_json)
@pytest.mark.p2
def test_list_chunk_exception_branches_unit(monkeypatch):
    """Drive list_chunk() through its success, guard, and error branches."""
    module = _load_chunk_module(monkeypatch)
    # Happy path: one available chunk comes back.
    _set_request_json(monkeypatch, module, {"doc_id": "doc-1", "keywords": "chunk", "available_int": 0})
    res = _run(module.list_chunk())
    assert res["code"] == 0, res
    assert res["data"]["total"] == 1, res
    assert res["data"]["chunks"][0]["available_int"] == 1, res
    # Empty tenant id -> DATA_ERROR "Tenant not found!".
    monkeypatch.setattr(module.DocumentService, "get_tenant_id", lambda _doc_id: "")
    _set_request_json(monkeypatch, module, {"doc_id": "doc-1"})
    res = _run(module.list_chunk())
    assert res["code"] == module.RetCode.DATA_ERROR, res
    assert res["message"] == "Tenant not found!", res
    # Missing document -> "Document not found!".
    monkeypatch.setattr(module.DocumentService, "get_tenant_id", lambda _doc_id: "tenant-1")
    monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (False, None))
    _set_request_json(monkeypatch, module, {"doc_id": "doc-1"})
    res = _run(module.list_chunk())
    assert res["message"] == "Document not found!", res

    async def _raise_not_found(*_args, **_kwargs):
        raise Exception("x not_found y")

    # A retriever error whose text contains "not_found" maps to "No chunk found!".
    monkeypatch.setattr(module.settings.retriever, "search", _raise_not_found)
    monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (True, _DummyDoc()))
    _set_request_json(monkeypatch, module, {"doc_id": "doc-1"})
    res = _run(module.list_chunk())
    assert res["code"] == module.RetCode.DATA_ERROR, res
    assert res["message"] == "No chunk found!", res

    async def _raise_generic(*_args, **_kwargs):
        raise RuntimeError("boom")

    # Any other retriever failure surfaces as EXCEPTION_ERROR.
    monkeypatch.setattr(module.settings.retriever, "search", _raise_generic)
    _set_request_json(monkeypatch, module, {"doc_id": "doc-1"})
    res = _run(module.list_chunk())
    assert res["code"] == module.RetCode.EXCEPTION_ERROR, res
    assert "boom" in res["message"], res
@pytest.mark.p2
def test_get_chunk_sanitize_and_exception_matrix_unit(monkeypatch):
    """get() must drop internal vector/token fields and map each failure mode."""
    module = _load_chunk_module(monkeypatch)
    # Success: internal fields are stripped from the returned chunk.
    module.request = SimpleNamespace(args={"chunk_id": "chunk-1"}, headers={})
    res = module.get()
    assert res["code"] == 0, res
    assert "q_2_vec" not in res["data"], res
    assert "content_tks" not in res["data"], res
    assert "content_ltks" not in res["data"], res
    assert "content_sm_ltks" not in res["data"], res
    # No tenant for the current user -> "Tenant not found!".
    monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [])
    res = module.get()
    assert res["message"] == "Tenant not found!", res
    # Store yields no chunk -> EXCEPTION_ERROR mentioning "Chunk not found".
    monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [_DummyTenant("tenant-1")])
    module.settings.docStoreConn.chunk = None
    res = module.get()
    assert res["code"] == module.RetCode.EXCEPTION_ERROR, res
    assert "Chunk not found" in res["message"], res

    def _raise_not_found(*_args, **_kwargs):
        raise Exception("NotFoundError: chunk-1")

    # A NotFoundError from the store is translated to DATA_ERROR.
    monkeypatch.setattr(module.settings.docStoreConn, "get", _raise_not_found)
    res = module.get()
    assert res["code"] == module.RetCode.DATA_ERROR, res
    assert res["message"] == "Chunk not found!", res

    def _raise_generic(*_args, **_kwargs):
        raise RuntimeError("get boom")

    # Any other store failure bubbles up as EXCEPTION_ERROR.
    monkeypatch.setattr(module.settings.docStoreConn, "get", _raise_generic)
    res = module.get()
    assert res["code"] == module.RetCode.EXCEPTION_ERROR, res
    assert "get boom" in res["message"], res
@pytest.mark.p2
def test_set_chunk_bytes_qa_image_and_guard_matrix_unit(monkeypatch):
    """set() input guards, bytes decoding, QA parsing and image storage paths."""
    module = _load_chunk_module(monkeypatch)
    # Non-string content reaches the tokenizer and raises TypeError.
    _set_request_json(monkeypatch, module, {"doc_id": "doc-1", "chunk_id": "chunk-1", "content_with_weight": 1})
    with pytest.raises(TypeError, match="expected string or bytes-like object"):
        _run(module.set())
    # important_kwd / question_kwd must be lists.
    _set_request_json(
        monkeypatch,
        module,
        {"doc_id": "doc-1", "chunk_id": "chunk-1", "content_with_weight": "abc", "important_kwd": "bad"},
    )
    res = _run(module.set())
    assert res["message"] == "`important_kwd` should be a list", res
    _set_request_json(
        monkeypatch,
        module,
        {"doc_id": "doc-1", "chunk_id": "chunk-1", "content_with_weight": "abc", "question_kwd": "bad"},
    )
    res = _run(module.set())
    assert res["message"] == "`question_kwd` should be a list", res
    # Tenant and document lookups failing short-circuit the update.
    monkeypatch.setattr(module.DocumentService, "get_tenant_id", lambda _doc_id: "")
    _set_request_json(monkeypatch, module, {"doc_id": "doc-1", "chunk_id": "chunk-1", "content_with_weight": "abc"})
    res = _run(module.set())
    assert res["message"] == "Tenant not found!", res
    monkeypatch.setattr(module.DocumentService, "get_tenant_id", lambda _doc_id: "tenant-1")
    monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (False, None))
    _set_request_json(monkeypatch, module, {"doc_id": "doc-1", "chunk_id": "chunk-1", "content_with_weight": "abc"})
    res = _run(module.set())
    assert res["message"] == "Document not found!", res
    # Bytes payload is decoded to text before being written to the store.
    monkeypatch.setattr(
        module.DocumentService,
        "get_by_id",
        lambda _doc_id: (True, _DummyDoc(doc_id="doc-1", parser_id=module.ParserType.NAIVE)),
    )
    _set_request_json(
        monkeypatch,
        module,
        {
            "doc_id": "doc-1",
            "chunk_id": "chunk-1",
            "content_with_weight": b"bytes-content",
            "important_kwd": ["important"],
            "question_kwd": ["question"],
            "tag_kwd": ["tag"],
            "tag_feas": [0.1],
            "available_int": 0,
        },
    )
    res = _run(module.set())
    assert res["code"] == 0, res
    assert module.settings.docStoreConn.updated[-1][1]["content_with_weight"] == "bytes-content"
    # QA parser plus a base64 image exercises the image-storage branch.
    monkeypatch.setattr(
        module.DocumentService,
        "get_by_id",
        lambda _doc_id: (True, _DummyDoc(doc_id="doc-1", parser_id=module.ParserType.QA)),
    )
    _set_request_json(
        monkeypatch,
        module,
        {
            "doc_id": "doc-1",
            "chunk_id": "chunk-2",
            "content_with_weight": "Q:Question\nA:Answer",
            "image_base64": base64.b64encode(b"image").decode("utf-8"),
            "img_id": "bucket-name",
        },
    )
    res = _run(module.set())
    assert res["code"] == 0, res
    assert module.settings.STORAGE_IMPL.put_calls, "image storage branch should be called"

    async def _raise_thread_pool(_func):
        raise RuntimeError("set tp boom")

    # thread_pool_exec failures surface as EXCEPTION_ERROR.
    monkeypatch.setattr(module, "thread_pool_exec", _raise_thread_pool)
    _set_request_json(monkeypatch, module, {"doc_id": "doc-1", "chunk_id": "chunk-1", "content_with_weight": "abc"})
    res = _run(module.set())
    assert res["code"] == module.RetCode.EXCEPTION_ERROR, res
    assert "set tp boom" in res["message"], res
@pytest.mark.p2
def test_switch_chunk_success_failure_and_exception_unit(monkeypatch):
    """switch() covers missing document, store failure, success and errors."""
    module = _load_chunk_module(monkeypatch)
    # Missing document guard.
    monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (False, None))
    _set_request_json(monkeypatch, module, {"doc_id": "doc-1", "chunk_ids": ["c1"], "available_int": 1})
    res = _run(module.switch())
    assert res["message"] == "Document not found!", res
    # docStoreConn.update returning False -> "Index updating failure".
    monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (True, _DummyDoc()))
    monkeypatch.setattr(module.DocumentService, "get_tenant_id", lambda _doc_id: "tenant-1")
    monkeypatch.setattr(module.settings.docStoreConn, "update", lambda *_args, **_kwargs: False)
    _set_request_json(monkeypatch, module, {"doc_id": "doc-1", "chunk_ids": ["c1", "c2"], "available_int": 0})
    res = _run(module.switch())
    assert res["message"] == "Index updating failure", res
    # Successful update returns data == True.
    monkeypatch.setattr(module.settings.docStoreConn, "update", lambda *_args, **_kwargs: True)
    _set_request_json(monkeypatch, module, {"doc_id": "doc-1", "chunk_ids": ["c1", "c2"], "available_int": 1})
    res = _run(module.switch())
    assert res["code"] == 0, res
    assert res["data"] is True, res

    async def _raise_thread_pool(_func):
        raise RuntimeError("switch tp boom")

    # Executor failure surfaces as EXCEPTION_ERROR.
    monkeypatch.setattr(module, "thread_pool_exec", _raise_thread_pool)
    _set_request_json(monkeypatch, module, {"doc_id": "doc-1", "chunk_ids": ["c1"], "available_int": 1})
    res = _run(module.switch())
    assert res["code"] == module.RetCode.EXCEPTION_ERROR, res
    assert "switch tp boom" in res["message"], res
@pytest.mark.p2
def test_rm_chunk_delete_exception_partial_compensation_and_cleanup_unit(monkeypatch):
    """rm() covers delete errors, partial deletes, compensation and scalar ids."""
    module = _load_chunk_module(monkeypatch)
    # Missing document guard.
    monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (False, None))
    _set_request_json(monkeypatch, module, {"doc_id": "doc-1", "chunk_ids": ["c1"]})
    res = _run(module.rm())
    assert res["message"] == "Document not found!", res
    monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (True, _DummyDoc()))

    def _raise_delete(*_args, **_kwargs):
        raise RuntimeError("delete boom")

    # Store delete raising -> "Chunk deleting failure".
    monkeypatch.setattr(module.settings.docStoreConn, "delete", _raise_delete)
    _set_request_json(monkeypatch, module, {"doc_id": "doc-1", "chunk_ids": ["c1"]})
    res = _run(module.rm())
    assert res["message"] == "Chunk deleting failure", res

    def _delete(condition, *_args, **_kwargs):
        # Record the delete condition and replay the scripted deletion counts.
        module.settings.docStoreConn.deleted_inputs.append(condition)
        if not module.settings.docStoreConn.to_delete:
            return 0
        return module.settings.docStoreConn.to_delete.pop(0)

    # Zero rows deleted -> "Index updating failure".
    module.settings.docStoreConn.to_delete = [0]
    monkeypatch.setattr(module.settings.docStoreConn, "delete", _delete)
    _set_request_json(monkeypatch, module, {"doc_id": "doc-1", "chunk_ids": ["c1"]})
    res = _run(module.rm())
    assert res["message"] == "Index updating failure", res
    # Partial delete succeeds; doc counters decrement and storage is cleaned.
    module.settings.docStoreConn.to_delete = [1, 2]
    _set_request_json(monkeypatch, module, {"doc_id": "doc-1", "chunk_ids": ["c1", "c2", "c3"]})
    res = _run(module.rm())
    assert res["code"] == 0, res
    assert module.DocumentService.decrement_calls, "decrement_chunk_num should be called"
    assert len(module.settings.STORAGE_IMPL.rm_calls) >= 1
    # A scalar (non-list) chunk_ids value is accepted as well.
    module.settings.docStoreConn.to_delete = [1]
    _set_request_json(monkeypatch, module, {"doc_id": "doc-1", "chunk_ids": "c1"})
    res = _run(module.rm())
    assert res["code"] == 0, res

    async def _raise_thread_pool(_func):
        raise RuntimeError("rm tp boom")

    # Executor failure surfaces as EXCEPTION_ERROR.
    monkeypatch.setattr(module, "thread_pool_exec", _raise_thread_pool)
    _set_request_json(monkeypatch, module, {"doc_id": "doc-1", "chunk_ids": ["c1"]})
    res = _run(module.rm())
    assert res["code"] == module.RetCode.EXCEPTION_ERROR, res
    assert "rm tp boom" in res["message"], res
@pytest.mark.p2
def test_create_chunk_guards_pagerank_and_success_unit(monkeypatch):
    """create() validates inputs, resolves kb pagerank, and inserts the chunk."""
    module = _load_chunk_module(monkeypatch)
    module.request = SimpleNamespace(headers={"X-Request-ID": "req-1"}, args={})
    # Keyword fields must be lists.
    _set_request_json(monkeypatch, module, {"doc_id": "doc-1", "content_with_weight": "chunk", "important_kwd": "bad"})
    res = _run(module.create())
    assert res["message"] == "`important_kwd` is required to be a list", res
    _set_request_json(monkeypatch, module, {"doc_id": "doc-1", "content_with_weight": "chunk", "question_kwd": "bad"})
    res = _run(module.create())
    assert res["message"] == "`question_kwd` is required to be a list", res
    # Document, tenant and knowledgebase lookups each guard the insert.
    monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (False, None))
    _set_request_json(monkeypatch, module, {"doc_id": "doc-1", "content_with_weight": "chunk"})
    res = _run(module.create())
    assert res["message"] == "Document not found!", res
    monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (True, _DummyDoc(doc_id="doc-1")))
    monkeypatch.setattr(module.DocumentService, "get_tenant_id", lambda _doc_id: "")
    _set_request_json(monkeypatch, module, {"doc_id": "doc-1", "content_with_weight": "chunk"})
    res = _run(module.create())
    assert res["message"] == "Tenant not found!", res
    monkeypatch.setattr(module.DocumentService, "get_tenant_id", lambda _doc_id: "tenant-1")
    monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (False, None))
    _set_request_json(monkeypatch, module, {"doc_id": "doc-1", "content_with_weight": "chunk"})
    res = _run(module.create())
    assert res["message"] == "Knowledgebase not found!", res
    # Success path: the inserted chunk carries the kb's pagerank.
    monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, SimpleNamespace(pagerank=0.8)))
    _set_request_json(
        monkeypatch,
        module,
        {
            "doc_id": "doc-1",
            "content_with_weight": "chunk",
            "important_kwd": ["i1"],
            "question_kwd": ["q1"],
            "tag_feas": [0.2],
        },
    )
    res = _run(module.create())
    assert res["code"] == 0, res
    assert module.settings.docStoreConn.inserted, "insert should be called"
    inserted = module.settings.docStoreConn.inserted[-1]
    assert "pagerank_flt" in inserted
    assert module.DocumentService.increment_calls, "increment_chunk_num should be called"

View File

@ -148,6 +148,35 @@ class TestAddChunk:
else:
assert res["message"] == expected_message, res
    @pytest.mark.p2
    def test_get_chunk_not_found(self, WebApiAuth):
        """Fetching a nonexistent chunk id must fail with a clear message."""
        res = get_chunk(WebApiAuth, {"chunk_id": "missing_chunk_id"})
        assert res["code"] != 0, res
        assert "Chunk not found" in res["message"], res
    @pytest.mark.p2
    def test_create_chunk_with_tag_fields(self, WebApiAuth, add_document):
        """Creating a chunk with tag_feas/keyword fields bumps the doc chunk count."""
        _, doc_id = add_document
        # Snapshot the current chunk count (0 when listing is not yet possible).
        res = list_chunks(WebApiAuth, {"doc_id": doc_id})
        if res["code"] == 0:
            chunks_count = res["data"]["doc"]["chunk_num"]
        else:
            chunks_count = 0
        payload = {
            "doc_id": doc_id,
            "content_with_weight": "chunk with tags",
            "tag_feas": [0.1, 0.2],
            "important_kwd": ["tag"],
            "question_kwd": ["question"],
        }
        res = add_chunk(WebApiAuth, payload)
        assert res["code"] == 0, res
        assert res["data"]["chunk_id"], res
        # The document's chunk counter reflects the new chunk.
        res = list_chunks(WebApiAuth, {"doc_id": doc_id})
        assert res["code"] == 0, res
        assert res["data"]["doc"]["chunk_num"] == chunks_count + 1, res
@pytest.mark.p3
@pytest.mark.parametrize(
"doc_id, expected_code, expected_message",

View File

@ -17,7 +17,7 @@ import os
from concurrent.futures import ThreadPoolExecutor, as_completed
import pytest
from common import batch_add_chunks, list_chunks
from common import batch_add_chunks, list_chunks, update_chunk
from configs import INVALID_API_TOKEN
from libs.auth import RAGFlowWebApiAuth
@ -88,6 +88,33 @@ class TestChunksList:
else:
assert res["message"] == expected_message, res
    @pytest.mark.p2
    def test_available_int_filter(self, WebApiAuth, add_chunks):
        """Listing with available_int=0 returns only disabled chunks."""
        _, doc_id, chunk_ids = add_chunks
        chunk_id = chunk_ids[0]
        # Disable one chunk so the filter has something to match.
        res = update_chunk(
            WebApiAuth,
            {"doc_id": doc_id, "chunk_id": chunk_id, "content_with_weight": "unchanged content", "available_int": 0},
        )
        assert res["code"] == 0, res
        from time import sleep

        sleep(1)  # allow the index to refresh before querying
        res = list_chunks(WebApiAuth, {"doc_id": doc_id, "available_int": 0})
        assert res["code"] == 0, res
        assert len(res["data"]["chunks"]) >= 1, res
        assert all(chunk["available_int"] == 0 for chunk in res["data"]["chunks"]), res
        # Restore the class-scoped fixture state for subsequent keyword cases.
        res = update_chunk(
            WebApiAuth,
            {"doc_id": doc_id, "chunk_id": chunk_id, "content_with_weight": "chunk test 0", "available_int": 1},
        )
        assert res["code"] == 0, res
        sleep(1)
@pytest.mark.p2
@pytest.mark.parametrize(
"params, expected_page_size",

View File

@ -95,6 +95,30 @@ class TestChunksDeletion:
assert len(res["data"]["chunks"]) == 0, res
assert res["data"]["total"] == 0, res
    @pytest.mark.p2
    def test_delete_scalar_chunk_id_payload(self, WebApiAuth, add_chunks_func):
        """A scalar (non-list) chunk_ids value deletes exactly that one chunk."""
        _, doc_id, chunk_ids = add_chunks_func
        payload = {"chunk_ids": chunk_ids[0], "doc_id": doc_id}
        res = delete_chunks(WebApiAuth, payload)
        assert res["code"] == 0, res
        # One of the fixture's four chunks is gone; three remain.
        res = list_chunks(WebApiAuth, {"doc_id": doc_id})
        assert res["code"] == 0, res
        assert len(res["data"]["chunks"]) == 3, res
        assert res["data"]["total"] == 3, res
    @pytest.mark.p2
    def test_delete_duplicate_ids_dedup_behavior(self, WebApiAuth, add_chunks_func):
        """Duplicated ids in chunk_ids delete the chunk only once."""
        _, doc_id, chunk_ids = add_chunks_func
        payload = {"chunk_ids": [chunk_ids[0], chunk_ids[0]], "doc_id": doc_id}
        res = delete_chunks(WebApiAuth, payload)
        assert res["code"] == 0, res
        # Only a single chunk was removed despite the duplicate id.
        res = list_chunks(WebApiAuth, {"doc_id": doc_id})
        assert res["code"] == 0, res
        assert len(res["data"]["chunks"]) == 3, res
        assert res["data"]["total"] == 3, res
@pytest.mark.p3
def test_concurrent_deletion(self, WebApiAuth, add_document):
count = 100

View File

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import base64
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from random import randint
@ -154,6 +155,32 @@ class TestUpdateChunk:
if chunk["chunk_id"] == chunk_id:
assert chunk["available_int"] == payload["available_int"]
    @pytest.mark.p2
    def test_update_chunk_qa_multiline_content(self, WebApiAuth, add_chunks):
        """Multiline (question/answer style) content round-trips unchanged."""
        _, doc_id, chunk_ids = add_chunks
        payload = {"doc_id": doc_id, "chunk_id": chunk_ids[0], "content_with_weight": "Question line\nAnswer line"}
        res = update_chunk(WebApiAuth, payload)
        assert res["code"] == 0, res
        sleep(1)  # allow the index to refresh before re-reading
        res = list_chunks(WebApiAuth, {"doc_id": doc_id})
        assert res["code"] == 0, res
        chunk = next(chunk for chunk in res["data"]["chunks"] if chunk["chunk_id"] == chunk_ids[0])
        assert chunk["content_with_weight"] == payload["content_with_weight"], res
    @pytest.mark.p2
    def test_update_chunk_with_image_payload(self, WebApiAuth, add_chunks):
        """An update carrying image_base64 + img_id is accepted."""
        _, doc_id, chunk_ids = add_chunks
        payload = {
            "doc_id": doc_id,
            "chunk_id": chunk_ids[0],
            "content_with_weight": "content with image",
            "image_base64": base64.b64encode(b"img").decode("utf-8"),
            "img_id": "bucket-name",
        }
        res = update_chunk(WebApiAuth, payload)
        assert res["code"] == 0, res
@pytest.mark.p3
@pytest.mark.parametrize(
"doc_id_param, expected_code, expected_message",

View File

@ -0,0 +1,219 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import importlib.util
import sys
from pathlib import Path
from types import ModuleType, SimpleNamespace
import pytest
class _DummyManager:
def route(self, *_args, **_kwargs):
def decorator(func):
return func
return decorator
class _DummyAtomic:
def __enter__(self):
return self
def __exit__(self, _exc_type, _exc, _tb):
return False
class _FakeApiError(Exception):
pass
class _FakeLangfuseClient:
    """Configurable fake of the Langfuse SDK client.

    auth_check() can succeed, fail, or raise a supplied exception, and
    api.projects.get() returns a canned project payload.
    """

    def __init__(self, *, auth_result=True, auth_exc=None, project_payload=None):
        self._auth_result = auth_result
        self._auth_exc = auth_exc
        payload = project_payload
        if payload is None:
            payload = {"data": [{"id": "project-id", "name": "project-name"}]}
        self.api = SimpleNamespace(
            projects=SimpleNamespace(get=lambda: SimpleNamespace(dict=lambda: payload)),
            core=SimpleNamespace(api_error=SimpleNamespace(ApiError=_FakeApiError)),
        )

    def auth_check(self):
        # Raise when configured to; otherwise report the configured result.
        if self._auth_exc is None:
            return self._auth_result
        raise self._auth_exc
def _run(coro):
    """Run *coro* to completion on a fresh asyncio event loop."""
    return asyncio.run(coro)
def _load_langfuse_app(monkeypatch):
    """Import api/apps/langfuse_app.py with its dependencies stubbed.

    Registers a stub `common` package path, a fake `api.apps` module (auth is
    made a no-op), and a fake `langfuse` SDK in sys.modules before executing
    the target module, so its route handlers can be called in-process.
    """
    repo_root = Path(__file__).resolve().parents[4]
    common_pkg = ModuleType("common")
    common_pkg.__path__ = [str(repo_root / "common")]
    monkeypatch.setitem(sys.modules, "common", common_pkg)
    stub_apps = ModuleType("api.apps")
    stub_apps.current_user = SimpleNamespace(id="tenant-1")
    stub_apps.login_required = lambda func: func  # bypass auth for direct calls
    monkeypatch.setitem(sys.modules, "api.apps", stub_apps)
    stub_langfuse = ModuleType("langfuse")
    stub_langfuse.Langfuse = _FakeLangfuseClient
    monkeypatch.setitem(sys.modules, "langfuse", stub_langfuse)
    module_path = repo_root / "api" / "apps" / "langfuse_app.py"
    spec = importlib.util.spec_from_file_location("test_langfuse_app_unit", module_path)
    module = importlib.util.module_from_spec(spec)
    module.manager = _DummyManager()  # neutralise blueprint route registration
    spec.loader.exec_module(module)
    return module
@pytest.mark.p2
def test_set_api_key_missing_fields_and_invalid_auth(monkeypatch):
    """set_api_key rejects incomplete payloads and failed auth checks."""
    module = _load_langfuse_app(monkeypatch)
    monkeypatch.setattr(module.DB, "atomic", lambda: _DummyAtomic())

    async def missing_fields():
        return {"secret_key": "", "public_key": "pub", "host": "http://host"}

    # An empty secret_key counts as a missing required field.
    monkeypatch.setattr(module, "get_request_json", missing_fields)
    res = _run(module.set_api_key.__wrapped__())
    assert res["code"] == 102
    assert res["message"] == "Missing required fields"

    async def invalid_auth():
        return {"secret_key": "sec", "public_key": "pub", "host": "http://host"}

    # auth_check() returning False -> "Invalid Langfuse keys".
    monkeypatch.setattr(module, "get_request_json", invalid_auth)
    monkeypatch.setattr(module, "Langfuse", lambda **_kwargs: _FakeLangfuseClient(auth_result=False))
    res = _run(module.set_api_key.__wrapped__())
    assert res["code"] == 102
    assert res["message"] == "Invalid Langfuse keys"
@pytest.mark.p2
def test_set_api_key_create_update_and_atomic_exception(monkeypatch):
    """set_api_key saves new keys, updates existing ones, and reports failures."""
    module = _load_langfuse_app(monkeypatch)
    monkeypatch.setattr(module.DB, "atomic", lambda: _DummyAtomic())
    monkeypatch.setattr(module, "Langfuse", lambda **_kwargs: _FakeLangfuseClient(auth_result=True))

    async def payload():
        return {"secret_key": "sec", "public_key": "pub", "host": "http://host"}

    monkeypatch.setattr(module, "get_request_json", payload)
    calls = {"save": 0, "update": 0}
    # No existing record -> save() is used.
    monkeypatch.setattr(module.TenantLangfuseService, "filter_by_tenant", lambda **_kwargs: None)
    monkeypatch.setattr(
        module.TenantLangfuseService,
        "save",
        lambda **_kwargs: calls.__setitem__("save", calls["save"] + 1),
    )
    monkeypatch.setattr(
        module.TenantLangfuseService,
        "update_by_tenant",
        lambda **_kwargs: calls.__setitem__("update", calls["update"] + 1),
    )
    res = _run(module.set_api_key.__wrapped__())
    assert res["code"] == 0
    assert calls["save"] == 1
    # Existing record -> update_by_tenant() is used.
    monkeypatch.setattr(module.TenantLangfuseService, "filter_by_tenant", lambda **_kwargs: {"id": "existing"})
    res = _run(module.set_api_key.__wrapped__())
    assert res["code"] == 0
    assert calls["update"] == 1
    # save() raising inside the atomic block -> code 100 with its message.
    monkeypatch.setattr(module.TenantLangfuseService, "filter_by_tenant", lambda **_kwargs: None)

    def raise_save(**_kwargs):
        raise RuntimeError("save failed")

    monkeypatch.setattr(module.TenantLangfuseService, "save", raise_save)
    res = _run(module.set_api_key.__wrapped__())
    assert res["code"] == 100
    assert "save failed" in res["message"]
@pytest.mark.p2
def test_get_api_key_no_record_invalid_auth_api_error_generic_error_success(monkeypatch):
    """get_api_key covers: no record, bad keys, SDK errors, and success."""
    module = _load_langfuse_app(monkeypatch)
    # No stored keys is reported as an informational, non-error message.
    monkeypatch.setattr(module.TenantLangfuseService, "filter_by_tenant_with_info", lambda **_kwargs: None)
    res = module.get_api_key.__wrapped__()
    assert res["code"] == 0
    assert res["message"] == "Have not record any Langfuse keys."
    base_entry = {"secret_key": "sec", "public_key": "pub", "host": "http://host"}
    # Stored keys failing auth_check -> code 102.
    monkeypatch.setattr(module.TenantLangfuseService, "filter_by_tenant_with_info", lambda **_kwargs: dict(base_entry))
    monkeypatch.setattr(module, "Langfuse", lambda **_kwargs: _FakeLangfuseClient(auth_result=False))
    res = module.get_api_key.__wrapped__()
    assert res["code"] == 102
    assert res["message"] == "Invalid Langfuse keys loaded"
    # The SDK's ApiError is surfaced in the message without failing the call.
    monkeypatch.setattr(
        module,
        "Langfuse",
        lambda **_kwargs: _FakeLangfuseClient(auth_exc=_FakeApiError("api exploded")),
    )
    res = module.get_api_key.__wrapped__()
    assert res["code"] == 0
    assert "Error from Langfuse" in res["message"]
    # Any other exception -> code 100 carrying the original message.
    monkeypatch.setattr(
        module,
        "Langfuse",
        lambda **_kwargs: _FakeLangfuseClient(auth_exc=RuntimeError("generic exploded")),
    )
    res = module.get_api_key.__wrapped__()
    assert res["code"] == 100
    assert "generic exploded" in res["message"]
    # Valid keys return the project id and name reported by the SDK.
    monkeypatch.setattr(module, "Langfuse", lambda **_kwargs: _FakeLangfuseClient(auth_result=True))
    res = module.get_api_key.__wrapped__()
    assert res["code"] == 0
    assert res["data"]["project_id"] == "project-id"
    assert res["data"]["project_name"] == "project-name"
@pytest.mark.p2
def test_delete_api_key_no_record_success_exception(monkeypatch):
    """delete_api_key: nothing stored, successful delete, failing delete."""
    module = _load_langfuse_app(monkeypatch)
    monkeypatch.setattr(module.DB, "atomic", lambda: _DummyAtomic())
    # No stored keys -> informational message, still code 0.
    monkeypatch.setattr(module.TenantLangfuseService, "filter_by_tenant", lambda **_kwargs: None)
    res = module.delete_api_key.__wrapped__()
    assert res["code"] == 0
    assert res["message"] == "Have not record any Langfuse keys."
    # Existing record deleted cleanly -> data is True.
    monkeypatch.setattr(module.TenantLangfuseService, "filter_by_tenant", lambda **_kwargs: {"id": "entry"})
    monkeypatch.setattr(module.TenantLangfuseService, "delete_model", lambda _entry: None)
    res = module.delete_api_key.__wrapped__()
    assert res["code"] == 0
    assert res["data"] is True

    def raise_delete(_entry):
        raise RuntimeError("delete failed")

    # delete_model raising -> code 100 with the failure message.
    monkeypatch.setattr(module.TenantLangfuseService, "delete_model", raise_delete)
    res = module.delete_api_key.__wrapped__()
    assert res["code"] == 100
    assert "delete failed" in res["message"]

View File

@ -0,0 +1,584 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import importlib.util
import sys
from copy import deepcopy
from pathlib import Path
from types import ModuleType, SimpleNamespace
import pytest
from anyio import Path as AsyncPath
class _DummyManager:
def route(self, *_args, **_kwargs):
def decorator(func):
return func
return decorator
class _AwaitableValue:
def __init__(self, value):
self._value = value
def __await__(self):
async def _co():
return self._value
return _co().__await__()
class _DummyRequest:
    """Quart-like request stub; form and files are awaitable mappings."""

    def __init__(self, *, args=None, headers=None, form=None, files=None):
        # Fixed method/length mirror what the handlers under test inspect.
        self.method = "POST"
        self.content_length = 0
        self.args = args or {}
        self.headers = headers or {}
        self.form = _AwaitableValue(form or {})
        self.files = _AwaitableValue(files or {})
class _DummyConversation:
def __init__(self, *, conv_id="conv-1", dialog_id="dialog-1", message=None, reference=None):
self.id = conv_id
self.dialog_id = dialog_id
self.message = message if message is not None else []
self.reference = reference if reference is not None else []
def to_dict(self):
return {
"id": self.id,
"dialog_id": self.dialog_id,
"message": deepcopy(self.message),
"reference": deepcopy(self.reference),
}
class _DummyDialog:
def __init__(self, *, dialog_id="dialog-1", tenant_id="tenant-1", icon="avatar.png"):
self.id = dialog_id
self.tenant_id = tenant_id
self.icon = icon
self.prompt_config = {"prologue": "hello"}
self.llm_id = ""
self.llm_setting = {}
def to_dict(self):
return {
"id": self.id,
"icon": self.icon,
"tenant_id": self.tenant_id,
"prompt_config": deepcopy(self.prompt_config),
}
class _DummyUploadedFile:
    # Stand-in for an uploaded file object: save() records the destination
    # path and writes fixed bytes there so file-handling code sees real data.
    def __init__(self, filename):
        self.filename = filename
        self.saved_path = None  # populated by save()

    async def save(self, path):
        """Remember *path* and asynchronously write placeholder bytes to it."""
        self.saved_path = path
        await AsyncPath(path).write_bytes(b"audio-bytes")
def _run(coro):
    """Run *coro* to completion on a fresh asyncio event loop."""
    return asyncio.run(coro)
def _load_conversation_module(monkeypatch):
    """Import api/apps/conversation_app.py with heavy dependencies stubbed.

    The deepdoc parsers, xgboost and the api.apps auth layer are replaced by
    lightweight stand-ins in sys.modules before execution, so the module can
    be loaded without the real parsing/ML stack; route registration is
    neutralised via _DummyManager.
    """
    repo_root = Path(__file__).resolve().parents[4]
    common_pkg = ModuleType("common")
    common_pkg.__path__ = [str(repo_root / "common")]
    monkeypatch.setitem(sys.modules, "common", common_pkg)
    deepdoc_pkg = ModuleType("deepdoc")
    deepdoc_parser_pkg = ModuleType("deepdoc.parser")
    deepdoc_parser_pkg.__path__ = []

    class _StubPdfParser:
        pass

    class _StubExcelParser:
        pass

    class _StubDocxParser:
        pass

    deepdoc_parser_pkg.PdfParser = _StubPdfParser
    deepdoc_parser_pkg.ExcelParser = _StubExcelParser
    deepdoc_parser_pkg.DocxParser = _StubDocxParser
    deepdoc_pkg.parser = deepdoc_parser_pkg
    monkeypatch.setitem(sys.modules, "deepdoc", deepdoc_pkg)
    monkeypatch.setitem(sys.modules, "deepdoc.parser", deepdoc_parser_pkg)
    deepdoc_excel_module = ModuleType("deepdoc.parser.excel_parser")
    deepdoc_excel_module.RAGFlowExcelParser = _StubExcelParser
    monkeypatch.setitem(sys.modules, "deepdoc.parser.excel_parser", deepdoc_excel_module)
    deepdoc_parser_utils = ModuleType("deepdoc.parser.utils")
    deepdoc_parser_utils.get_text = lambda *_args, **_kwargs: ""
    monkeypatch.setitem(sys.modules, "deepdoc.parser.utils", deepdoc_parser_utils)
    monkeypatch.setitem(sys.modules, "xgboost", ModuleType("xgboost"))
    apps_mod = ModuleType("api.apps")
    apps_mod.current_user = SimpleNamespace(id="user-1")
    apps_mod.login_required = lambda func: func  # bypass auth in handlers
    monkeypatch.setitem(sys.modules, "api.apps", apps_mod)
    module_name = "test_conversation_routes_unit_module"
    module_path = repo_root / "api" / "apps" / "conversation_app.py"
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    module = importlib.util.module_from_spec(spec)
    module.manager = _DummyManager()  # no-op route registration
    monkeypatch.setitem(sys.modules, module_name, module)
    spec.loader.exec_module(module)
    return module
def _set_request_json(monkeypatch, module, payload):
    """Patch get_request_json so each call yields a deep copy of *payload*."""

    def _fake_get_request_json():
        # Copy per call so a handler mutating the payload cannot leak
        # state into subsequent requests.
        return _AwaitableValue(deepcopy(payload))

    monkeypatch.setattr(module, "get_request_json", _fake_get_request_json)
async def _read_sse_text(response):
chunks = []
async for chunk in response.response:
if isinstance(chunk, bytes):
chunks.append(chunk.decode("utf-8"))
else:
chunks.append(chunk)
return "".join(chunks)
@pytest.mark.p2
def test_set_conversation_update_create_and_errors(monkeypatch):
    """set_conversation() create/update paths plus their failure modes."""
    module = _load_conversation_module(monkeypatch)
    # Create path: an over-long name is truncated to 255 characters.
    long_name = "n" * 300
    create_payload = {
        "conversation_id": "conv-new",
        "dialog_id": "dialog-1",
        "is_new": True,
        "name": long_name,
    }
    _set_request_json(monkeypatch, module, create_payload)
    saved = {}
    monkeypatch.setattr(module.DialogService, "get_by_id", lambda _id: (True, _DummyDialog()))
    monkeypatch.setattr(module.ConversationService, "save", lambda **kwargs: saved.update(kwargs) or True)
    res = _run(module.set_conversation())
    assert res["code"] == 0
    assert len(res["data"]["name"]) == 255
    assert saved["user_id"] == "user-1"
    # Update path: update_by_id returning False -> "Conversation not found".
    update_payload = {
        "conversation_id": "conv-1",
        "dialog_id": "dialog-1",
        "is_new": False,
        "name": "rename",
    }
    _set_request_json(monkeypatch, module, update_payload)
    monkeypatch.setattr(module.ConversationService, "update_by_id", lambda *_args, **_kwargs: False)
    res = _run(module.set_conversation())
    assert "Conversation not found" in res["message"]
    # Update succeeded but the re-read failed -> "Fail to update".
    _set_request_json(monkeypatch, module, update_payload)
    monkeypatch.setattr(module.ConversationService, "update_by_id", lambda *_args, **_kwargs: True)
    monkeypatch.setattr(module.ConversationService, "get_by_id", lambda _id: (False, None))
    res = _run(module.set_conversation())
    assert "Fail to update" in res["message"]
    # Fully successful rename returns the updated conversation.
    _set_request_json(monkeypatch, module, update_payload)
    monkeypatch.setattr(module.ConversationService, "update_by_id", lambda *_args, **_kwargs: True)
    monkeypatch.setattr(module.ConversationService, "get_by_id", lambda _id: (True, _DummyConversation(conv_id="conv-1")))
    res = _run(module.set_conversation())
    assert res["code"] == 0
    assert res["data"]["id"] == "conv-1"
    # An unexpected update exception -> EXCEPTION_ERROR with its message.
    _set_request_json(monkeypatch, module, update_payload)

    def _raise_update(*_args, **_kwargs):
        raise RuntimeError("update boom")

    monkeypatch.setattr(module.ConversationService, "update_by_id", _raise_update)
    res = _run(module.set_conversation())
    assert res["code"] == module.RetCode.EXCEPTION_ERROR
    assert "update boom" in res["message"]
    # Create path with an unknown dialog -> "Dialog not found".
    missing_dialog_payload = {
        "conversation_id": "conv-2",
        "dialog_id": "dialog-missing",
        "is_new": True,
        "name": "create",
    }
    _set_request_json(monkeypatch, module, missing_dialog_payload)
    monkeypatch.setattr(module.DialogService, "get_by_id", lambda _id: (False, None))
    res = _run(module.set_conversation())
    assert res["message"] == "Dialog not found"
    # Dialog lookup raising -> EXCEPTION_ERROR.
    _set_request_json(monkeypatch, module, missing_dialog_payload)

    def _raise_dialog(_id):
        raise RuntimeError("dialog boom")

    monkeypatch.setattr(module.DialogService, "get_by_id", _raise_dialog)
    res = _run(module.set_conversation())
    assert res["code"] == module.RetCode.EXCEPTION_ERROR
    assert "dialog boom" in res["message"]
@pytest.mark.p2
def test_get_and_getsse_authorization_and_reference_paths(monkeypatch):
    """Cover get() reference normalization and authorization, plus getsse()
    auth-header guards, success shape, and exception handling."""
    module = _load_conversation_module(monkeypatch)

    # get(): dict references are normalized via chunks_format; list entries
    # are assumed to be already formatted.
    conv = _DummyConversation(reference=[{"doc": "d"}, ["already-formatted"]])
    monkeypatch.setattr(module, "request", _DummyRequest(args={"conversation_id": "conv-1"}))
    monkeypatch.setattr(module.ConversationService, "get_by_id", lambda _id: (True, conv))
    monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [SimpleNamespace(tenant_id="tenant-1")])
    monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: [SimpleNamespace(icon="bot-avatar")])
    monkeypatch.setattr(module, "chunks_format", lambda _ref: [{"chunk": "normalized"}])
    res = _run(module.get())
    assert res["code"] == 0
    assert res["data"]["avatar"] == "bot-avatar"
    assert res["data"]["reference"][0]["chunks"] == [{"chunk": "normalized"}]

    # get(): missing conversation id.
    monkeypatch.setattr(module.ConversationService, "get_by_id", lambda _id: (False, None))
    res = _run(module.get())
    assert res["message"] == "Conversation not found!"

    # get(): none of the caller's tenants own the dialog -> operating error.
    monkeypatch.setattr(module, "request", _DummyRequest(args={"conversation_id": "conv-1"}))
    monkeypatch.setattr(module.ConversationService, "get_by_id", lambda _id: (True, conv))
    monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: [])
    res = _run(module.get())
    assert res["code"] == module.RetCode.OPERATING_ERROR
    assert "Only owner of conversation" in res["message"]

    # get(): unexpected exception surfaces as EXCEPTION_ERROR.
    def _raise_get(*_args, **_kwargs):
        raise RuntimeError("get boom")

    monkeypatch.setattr(module.ConversationService, "get_by_id", _raise_get)
    res = _run(module.get())
    assert res["code"] == module.RetCode.EXCEPTION_ERROR
    assert "get boom" in res["message"]

    # getsse(): a bare "Bearer" header with no token is rejected.
    monkeypatch.setattr(module, "request", _DummyRequest(headers={"Authorization": "Bearer"}))
    res = module.getsse("dialog-1")
    assert "Authorization is not valid" in res["message"]

    # getsse(): unknown API token is rejected.
    monkeypatch.setattr(module, "request", _DummyRequest(headers={"Authorization": "Bearer token-1"}))
    monkeypatch.setattr(module.APIToken, "query", lambda **_kwargs: [])
    res = module.getsse("dialog-1")
    assert "API key is invalid" in res["message"]

    # getsse(): valid token but the dialog record is missing.
    monkeypatch.setattr(module.APIToken, "query", lambda **_kwargs: [SimpleNamespace()])
    monkeypatch.setattr(module.DialogService, "get_by_id", lambda _id: (False, None))
    res = module.getsse("dialog-1")
    assert res["message"] == "Dialog not found!"

    # getsse(): success path exposes avatar and drops the internal icon field.
    monkeypatch.setattr(module.DialogService, "get_by_id", lambda _id: (True, _DummyDialog()))
    res = module.getsse("dialog-1")
    assert res["code"] == 0
    assert res["data"]["avatar"] == "avatar.png"
    assert "icon" not in res["data"]

    # getsse(): unexpected exception surfaces as EXCEPTION_ERROR.
    def _raise_getsse(_id):
        raise RuntimeError("getsse boom")

    monkeypatch.setattr(module.DialogService, "get_by_id", _raise_getsse)
    res = module.getsse("dialog-1")
    assert res["code"] == module.RetCode.EXCEPTION_ERROR
    assert "getsse boom" in res["message"]
@pytest.mark.p2
def test_rm_and_list_conversation_guards(monkeypatch):
    """Cover rm() ownership/deletion guards and list_conversation() branches."""
    module = _load_conversation_module(monkeypatch)

    # rm(): unknown conversation id.
    _set_request_json(monkeypatch, module, {"conversation_ids": ["conv-1"]})
    monkeypatch.setattr(module.ConversationService, "get_by_id", lambda _id: (False, None))
    res = _run(module.rm())
    assert "Conversation not found" in res["message"]

    # rm(): conversation exists but no tenant of the caller owns its dialog.
    conv = _DummyConversation(conv_id="conv-1", dialog_id="dialog-1")
    _set_request_json(monkeypatch, module, {"conversation_ids": ["conv-1"]})
    monkeypatch.setattr(module.ConversationService, "get_by_id", lambda _id: (True, conv))
    monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [SimpleNamespace(tenant_id="tenant-1")])
    monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: [])
    res = _run(module.rm())
    assert res["code"] == module.RetCode.OPERATING_ERROR

    # rm(): happy path records the deleted conversation id.
    deleted = []
    _set_request_json(monkeypatch, module, {"conversation_ids": ["conv-1"]})
    monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: [SimpleNamespace(id="dialog-1")])
    # `append(...) or True` records the id while still reporting success.
    monkeypatch.setattr(module.ConversationService, "delete_by_id", lambda cid: deleted.append(cid) or True)
    res = _run(module.rm())
    assert res["code"] == 0
    assert res["data"] is True
    assert deleted == ["conv-1"]

    # rm(): unexpected exception surfaces as EXCEPTION_ERROR.
    _set_request_json(monkeypatch, module, {"conversation_ids": ["conv-1"]})

    def _raise_rm(*_args, **_kwargs):
        raise RuntimeError("rm boom")

    monkeypatch.setattr(module.ConversationService, "get_by_id", _raise_rm)
    res = _run(module.rm())
    assert res["code"] == module.RetCode.EXCEPTION_ERROR
    assert "rm boom" in res["message"]

    # list_conversation(): caller does not own the dialog.
    monkeypatch.setattr(module, "request", _DummyRequest(args={"dialog_id": "dialog-1"}))
    monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: [])
    res = _run(module.list_conversation())
    assert res["code"] == module.RetCode.OPERATING_ERROR
    assert "Only owner of dialog" in res["message"]

    # list_conversation(): success path returns conversations in query order.
    monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: [SimpleNamespace(id="dialog-1")])
    monkeypatch.setattr(module.ConversationService, "model", SimpleNamespace(create_time="create_time"))
    monkeypatch.setattr(module.ConversationService, "query", lambda **_kwargs: [_DummyConversation(conv_id="c1"), _DummyConversation(conv_id="c2")])
    res = _run(module.list_conversation())
    assert res["code"] == 0
    assert [x["id"] for x in res["data"]] == ["c1", "c2"]

    # list_conversation(): unexpected exception surfaces as EXCEPTION_ERROR.
    def _raise_list(**_kwargs):
        raise RuntimeError("list boom")

    monkeypatch.setattr(module.ConversationService, "query", _raise_list)
    res = _run(module.list_conversation())
    assert res["code"] == module.RetCode.EXCEPTION_ERROR
    assert "list boom" in res["message"]
@pytest.mark.p2
def test_completion_stream_and_nonstream_branches(monkeypatch):
    """Cover completion(): SSE streaming (success and mid-stream error),
    non-stream answers, per-request LLM overrides, and the not-found and
    exception guards."""
    module = _load_conversation_module(monkeypatch)
    conv = _DummyConversation(conv_id="conv-1", dialog_id="dialog-1", reference=[])
    dia = _DummyDialog(dialog_id="dialog-1", tenant_id="tenant-1")
    monkeypatch.setattr(module.ConversationService, "get_by_id", lambda _id: (True, conv))
    monkeypatch.setattr(module.DialogService, "get_by_id", lambda _id: (True, dia))
    monkeypatch.setattr(module, "structure_answer", lambda _conv, ans, message_id, conv_id: {"answer": ans["answer"], "id": message_id, "conversation_id": conv_id, "reference": []})
    updates = []
    # `append(...) or True` records each persisted update while reporting success.
    monkeypatch.setattr(module.ConversationService, "update_by_id", lambda conv_id, payload: updates.append((conv_id, payload)) or True)

    # Streaming success: system and leading assistant messages must be
    # stripped before the chat backend sees the history.
    stream_payload = {
        "conversation_id": "conv-1",
        "messages": [
            {"role": "system", "content": "ignored"},
            {"role": "assistant", "content": "ignored-first-assistant"},
            {"role": "user", "content": "hello", "id": "m-1"},
        ],
        "stream": True,
    }

    async def _stream_ok(_dia, sanitized, *_args, **_kwargs):
        # Only the user message should survive sanitization.
        assert [m["role"] for m in sanitized] == ["user"]
        yield {"answer": "sse-ok"}

    monkeypatch.setattr(module, "async_chat", _stream_ok)
    _set_request_json(monkeypatch, module, stream_payload)
    # __wrapped__ bypasses the route decorator so the view runs directly.
    resp = _run(module.completion.__wrapped__())
    assert resp.headers["Content-Type"].startswith("text/event-stream")
    sse_text = _run(_read_sse_text(resp))
    assert "sse-ok" in sse_text
    assert '"data": true' in sse_text
    assert updates

    # Mid-stream failure is reported inside the SSE body, not as an HTTP error.
    async def _stream_error(_dia, _sanitized, *_args, **_kwargs):
        raise RuntimeError("stream explode")
        # Unreachable yield keeps this an async generator like the real chat.
        if False:
            yield {"answer": "never"}

    monkeypatch.setattr(module, "async_chat", _stream_error)
    _set_request_json(monkeypatch, module, stream_payload)
    resp = _run(module.completion.__wrapped__())
    sse_text = _run(_read_sse_text(resp))
    assert "**ERROR**: stream explode" in sse_text

    # Non-stream success returns the final answer as a plain JSON payload.
    async def _non_stream(_dia, _sanitized, **_kwargs):
        yield {"answer": "plain-ok"}

    monkeypatch.setattr(module, "async_chat", _non_stream)
    _set_request_json(
        monkeypatch,
        module,
        {
            "conversation_id": "conv-1",
            "messages": [{"role": "user", "content": "plain", "id": "m-2"}],
            "stream": False,
        },
    )
    res = _run(module.completion.__wrapped__())
    assert res["code"] == 0
    assert res["data"]["answer"] == "plain-ok"

    # llm_id override without a usable API key is rejected.
    monkeypatch.setattr(module.TenantLLMService, "get_api_key", lambda **_kwargs: False)
    _set_request_json(
        monkeypatch,
        module,
        {
            "conversation_id": "conv-1",
            "messages": [{"role": "user", "content": "embed", "id": "m-3"}],
            "llm_id": "bad-model",
            "stream": False,
        },
    )
    res = _run(module.completion.__wrapped__())
    assert "Cannot use specified model bad-model" in res["message"]

    # llm_id override with a key rewrites the dialog's model and LLM settings.
    monkeypatch.setattr(module.TenantLLMService, "get_api_key", lambda **_kwargs: "api-key")
    _set_request_json(
        monkeypatch,
        module,
        {
            "conversation_id": "conv-1",
            "messages": [{"role": "user", "content": "embed", "id": "m-4"}],
            "llm_id": "glm-4",
            "temperature": 0.7,
            "top_p": 0.2,
            "stream": False,
        },
    )
    res = _run(module.completion.__wrapped__())
    assert res["code"] == 0
    assert dia.llm_id == "glm-4"
    assert dia.llm_setting == {"temperature": 0.7, "top_p": 0.2}

    # Missing conversation record.
    _set_request_json(
        monkeypatch,
        module,
        {
            "conversation_id": "missing",
            "messages": [{"role": "user", "content": "x", "id": "m-5"}],
            "stream": False,
        },
    )
    monkeypatch.setattr(module.ConversationService, "get_by_id", lambda _id: (False, None))
    res = _run(module.completion.__wrapped__())
    assert res["message"] == "Conversation not found!"

    # Missing dialog record.
    monkeypatch.setattr(module.ConversationService, "get_by_id", lambda _id: (True, conv))
    monkeypatch.setattr(module.DialogService, "get_by_id", lambda _id: (False, None))
    _set_request_json(
        monkeypatch,
        module,
        {
            "conversation_id": "conv-1",
            "messages": [{"role": "user", "content": "x", "id": "m-6"}],
            "stream": False,
        },
    )
    res = _run(module.completion.__wrapped__())
    assert res["message"] == "Dialog not found!"

    # Unexpected exception surfaces as EXCEPTION_ERROR; the throwing lambda
    # uses the generator-throw trick to raise from an expression.
    monkeypatch.setattr(module.ConversationService, "get_by_id", lambda _id: (_ for _ in ()).throw(RuntimeError("completion boom")))
    _set_request_json(
        monkeypatch,
        module,
        {
            "conversation_id": "conv-1",
            "messages": [{"role": "user", "content": "x", "id": "m-7"}],
            "stream": False,
        },
    )
    res = _run(module.completion.__wrapped__())
    assert res["code"] == module.RetCode.EXCEPTION_ERROR
    assert "completion boom" in res["message"]
@pytest.mark.p2
def test_sequence2txt_validation_and_transcription_paths(monkeypatch):
    """Cover sequence2txt(): upload validation, tenant/ASR-model guards, and
    the synchronous, streaming, and stream-error transcription paths."""
    module = _load_conversation_module(monkeypatch)

    # No file part in the multipart form.
    monkeypatch.setattr(module, "request", _DummyRequest(form={"stream": "false"}, files={}))
    res = _run(module.sequence2txt())
    assert "Missing 'file'" in res["message"]

    # Unsupported (non-audio) file extension.
    bad_file = _DummyUploadedFile("audio.txt")
    monkeypatch.setattr(module, "request", _DummyRequest(form={"stream": "false"}, files={"file": bad_file}))
    res = _run(module.sequence2txt())
    assert "Unsupported audio format" in res["message"]

    # No tenant record for the current user.
    wav_file = _DummyUploadedFile("audio.wav")
    monkeypatch.setattr(module, "request", _DummyRequest(form={"stream": "false"}, files={"file": wav_file}))
    monkeypatch.setattr(module.TenantService, "get_info_by", lambda _uid: [])
    res = _run(module.sequence2txt())
    assert res["message"] == "Tenant not found!"

    # Tenant exists but has no default ASR model configured.
    wav_file = _DummyUploadedFile("audio.wav")
    monkeypatch.setattr(module, "request", _DummyRequest(form={"stream": "false"}, files={"file": wav_file}))
    monkeypatch.setattr(module.TenantService, "get_info_by", lambda _uid: [{"tenant_id": "tenant-1", "asr_id": ""}])
    res = _run(module.sequence2txt())
    assert res["message"] == "No default ASR model is set"

    # Synchronous path; a failing temp-file cleanup must not break the response.
    class _SyncAsr:
        def transcription(self, _path):
            return "transcribed text"

        def stream_transcription(self, _path):
            return []

    wav_file = _DummyUploadedFile("audio.wav")
    monkeypatch.setattr(module, "request", _DummyRequest(form={"stream": "false"}, files={"file": wav_file}))
    monkeypatch.setattr(module.TenantService, "get_info_by", lambda _uid: [{"tenant_id": "tenant-1", "asr_id": "asr-model"}])
    monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: _SyncAsr())
    # os.remove raising simulates a cleanup failure that should be swallowed.
    monkeypatch.setattr(module.os, "remove", lambda _path: (_ for _ in ()).throw(RuntimeError("remove failed")))
    res = _run(module.sequence2txt())
    assert res["code"] == 0
    assert res["data"]["text"] == "transcribed text"

    # Streaming path emits transcription events over SSE.
    class _StreamAsr:
        def transcription(self, _path):
            return ""

        def stream_transcription(self, _path):
            yield {"event": "partial", "text": "hello"}

    wav_file = _DummyUploadedFile("audio.wav")
    monkeypatch.setattr(module, "request", _DummyRequest(form={"stream": "true"}, files={"file": wav_file}))
    monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: _StreamAsr())
    resp = _run(module.sequence2txt())
    assert resp.headers["Content-Type"].startswith("text/event-stream")
    sse_text = _run(_read_sse_text(resp))
    assert '"event": "partial"' in sse_text

    # Streaming-path failure is reported inside the SSE body.
    class _ErrorStreamAsr:
        def transcription(self, _path):
            return ""

        def stream_transcription(self, _path):
            raise RuntimeError("stream asr boom")

    wav_file = _DummyUploadedFile("audio.wav")
    monkeypatch.setattr(module, "request", _DummyRequest(form={"stream": "true"}, files={"file": wav_file}))
    monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: _ErrorStreamAsr())
    resp = _run(module.sequence2txt())
    sse_text = _run(_read_sse_text(resp))
    assert "stream asr boom" in sse_text
@pytest.mark.p2
def test_tts_request_parse_entry(monkeypatch):
    """tts() reports an error when no tenant record exists for the user."""
    conv_module = _load_conversation_module(monkeypatch)
    _set_request_json(monkeypatch, conv_module, {"text": "hello"})
    monkeypatch.setattr(conv_module.TenantService, "get_info_by", lambda _uid: [])
    result = _run(conv_module.tts())
    assert result["message"] == "Tenant not found!"

View File

@ -15,10 +15,22 @@
#
import importlib.util
import sys
from pathlib import Path
from types import ModuleType, SimpleNamespace
import pytest
from common import bulk_upload_documents, delete_document, list_documents
class _DummyManager:
def route(self, *_args, **_kwargs):
def decorator(func):
return func
return decorator
@pytest.fixture(scope="function")
def add_document_func(request, WebApiAuth, add_dataset, ragflow_tmp_dir):
def cleanup():
@ -56,3 +68,49 @@ def add_documents_func(request, WebApiAuth, add_dataset_func, ragflow_tmp_dir):
dataset_id = add_dataset_func
return dataset_id, bulk_upload_documents(WebApiAuth, dataset_id, 3, ragflow_tmp_dir)
@pytest.fixture()
def document_app_module(monkeypatch):
    """Import api/apps/document_app.py in isolation for unit testing.

    Heavy dependencies (deepdoc parsers, xgboost, the Flask app package) are
    replaced with stub modules in sys.modules before the file is executed, so
    the import runs without optional packages or application side effects.
    monkeypatch.setitem keeps every sys.modules change scoped to the test.
    """
    repo_root = Path(__file__).resolve().parents[4]

    # Stub the top-level `common` package but keep its real path so genuine
    # submodules still resolve from the repository.
    common_pkg = ModuleType("common")
    common_pkg.__path__ = [str(repo_root / "common")]
    monkeypatch.setitem(sys.modules, "common", common_pkg)

    # Stub deepdoc and its parser submodules with empty parser classes.
    deepdoc_pkg = ModuleType("deepdoc")
    deepdoc_parser_pkg = ModuleType("deepdoc.parser")
    deepdoc_parser_pkg.__path__ = []

    class _StubPdfParser:
        pass

    class _StubExcelParser:
        pass

    deepdoc_parser_pkg.PdfParser = _StubPdfParser
    deepdoc_pkg.parser = deepdoc_parser_pkg
    monkeypatch.setitem(sys.modules, "deepdoc", deepdoc_pkg)
    monkeypatch.setitem(sys.modules, "deepdoc.parser", deepdoc_parser_pkg)
    deepdoc_excel_module = ModuleType("deepdoc.parser.excel_parser")
    deepdoc_excel_module.RAGFlowExcelParser = _StubExcelParser
    monkeypatch.setitem(sys.modules, "deepdoc.parser.excel_parser", deepdoc_excel_module)
    deepdoc_html_module = ModuleType("deepdoc.parser.html_parser")

    class _StubHtmlParser:
        pass

    deepdoc_html_module.RAGFlowHtmlParser = _StubHtmlParser
    monkeypatch.setitem(sys.modules, "deepdoc.parser.html_parser", deepdoc_html_module)

    # xgboost is imported transitively; an empty module satisfies the import.
    monkeypatch.setitem(sys.modules, "xgboost", ModuleType("xgboost"))

    # Fake api.apps: a fixed current_user and a pass-through login_required.
    stub_apps = ModuleType("api.apps")
    stub_apps.current_user = SimpleNamespace(id="user-1")
    stub_apps.login_required = lambda func: func
    monkeypatch.setitem(sys.modules, "api.apps", stub_apps)

    # Execute document_app.py under a private module name with a no-op
    # route manager so no real routes get registered.
    module_path = repo_root / "api" / "apps" / "document_app.py"
    spec = importlib.util.spec_from_file_location("test_document_app_unit", module_path)
    module = importlib.util.module_from_spec(spec)
    module.manager = _DummyManager()
    spec.loader.exec_module(module)
    return module

View File

@ -13,7 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import string
from types import SimpleNamespace
from concurrent.futures import ThreadPoolExecutor, as_completed
import pytest
@ -21,6 +23,7 @@ from common import create_document, list_kbs
from configs import DOCUMENT_NAME_LIMIT, INVALID_API_TOKEN
from libs.auth import RAGFlowWebApiAuth
from utils.file_utils import create_txt_file
from api.constants import FILE_NAME_LEN_LIMIT
@pytest.mark.p1
@ -90,3 +93,130 @@ class TestDocumentCreate:
res = list_kbs(WebApiAuth, {"id": kb_id})
assert res["data"]["kbs"][0]["doc_num"] == count, res
def _run(coro):
return asyncio.run(coro)
@pytest.mark.p2
class TestDocumentCreateUnit:
    """Unit tests for document_app.create(), run against the isolated module
    provided by the document_app_module fixture."""

    def test_missing_kb_id(self, document_app_module, monkeypatch):
        # An empty kb_id is rejected before any service lookup.
        module = document_app_module

        async def fake_request_json():
            return {"kb_id": "", "name": "doc.txt"}

        monkeypatch.setattr(module, "get_request_json", fake_request_json)
        # __wrapped__ bypasses the route/login decorators.
        res = _run(module.create.__wrapped__())
        assert res["code"] == 101
        assert res["message"] == 'Lack of "KB ID"'

    def test_filename_too_long(self, document_app_module, monkeypatch):
        # Names over FILE_NAME_LEN_LIMIT bytes are rejected.
        module = document_app_module
        long_name = "a" * (FILE_NAME_LEN_LIMIT + 1)

        async def fake_request_json():
            return {"kb_id": "kb1", "name": long_name}

        monkeypatch.setattr(module, "get_request_json", fake_request_json)
        res = _run(module.create.__wrapped__())
        assert res["code"] == 101
        assert res["message"] == f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less."

    def test_filename_whitespace(self, document_app_module, monkeypatch):
        # A whitespace-only name counts as empty.
        module = document_app_module

        async def fake_request_json():
            return {"kb_id": "kb1", "name": " "}

        monkeypatch.setattr(module, "get_request_json", fake_request_json)
        res = _run(module.create.__wrapped__())
        assert res["code"] == 101
        assert res["message"] == "File name can't be empty."

    def test_kb_not_found(self, document_app_module, monkeypatch):
        # An unknown dataset id is rejected.
        module = document_app_module
        monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (False, None))

        async def fake_request_json():
            return {"kb_id": "missing", "name": "doc.txt"}

        monkeypatch.setattr(module, "get_request_json", fake_request_json)
        res = _run(module.create.__wrapped__())
        assert res["code"] == 102
        assert res["message"] == "Can't find this dataset!"

    def test_duplicate_name(self, document_app_module, monkeypatch):
        # An existing document with the same name blocks creation.
        module = document_app_module
        kb = SimpleNamespace(id="kb1", tenant_id="tenant1", name="kb", parser_id="parser", pipeline_id="pipe", parser_config={})
        monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, kb))
        monkeypatch.setattr(module.DocumentService, "query", lambda **_kwargs: [object()])

        async def fake_request_json():
            return {"kb_id": "kb1", "name": "doc.txt"}

        monkeypatch.setattr(module, "get_request_json", fake_request_json)
        res = _run(module.create.__wrapped__())
        assert res["code"] == 102
        assert "Duplicated document name" in res["message"]

    def test_root_folder_missing(self, document_app_module, monkeypatch):
        # A missing root folder is a server-side data error.
        module = document_app_module
        kb = SimpleNamespace(id="kb1", tenant_id="tenant1", name="kb", parser_id="parser", pipeline_id="pipe", parser_config={})
        monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, kb))
        monkeypatch.setattr(module.DocumentService, "query", lambda **_kwargs: [])
        monkeypatch.setattr(module.FileService, "get_kb_folder", lambda *_args, **_kwargs: None)

        async def fake_request_json():
            return {"kb_id": "kb1", "name": "doc.txt"}

        monkeypatch.setattr(module, "get_request_json", fake_request_json)
        res = _run(module.create.__wrapped__())
        assert res["code"] == 102
        assert res["message"] == "Cannot find the root folder."

    def test_kb_folder_missing(self, document_app_module, monkeypatch):
        # Failure to materialize the kb folder for the new file is an error.
        module = document_app_module
        kb = SimpleNamespace(id="kb1", tenant_id="tenant1", name="kb", parser_id="parser", pipeline_id="pipe", parser_config={})
        monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, kb))
        monkeypatch.setattr(module.DocumentService, "query", lambda **_kwargs: [])
        monkeypatch.setattr(module.FileService, "get_kb_folder", lambda *_args, **_kwargs: {"id": "root"})
        monkeypatch.setattr(module.FileService, "new_a_file_from_kb", lambda *_args, **_kwargs: None)

        async def fake_request_json():
            return {"kb_id": "kb1", "name": "doc.txt"}

        monkeypatch.setattr(module, "get_request_json", fake_request_json)
        res = _run(module.create.__wrapped__())
        assert res["code"] == 102
        assert res["message"] == "Cannot find the kb folder for this file."

    def test_success(self, document_app_module, monkeypatch):
        # Happy path: document is inserted and its JSON form returned.
        module = document_app_module
        kb = SimpleNamespace(id="kb1", tenant_id="tenant1", name="kb", parser_id="parser", pipeline_id="pipe", parser_config={})
        monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, kb))
        monkeypatch.setattr(module.DocumentService, "query", lambda **_kwargs: [])
        monkeypatch.setattr(module.FileService, "get_kb_folder", lambda *_args, **_kwargs: {"id": "root"})
        monkeypatch.setattr(module.FileService, "new_a_file_from_kb", lambda *_args, **_kwargs: {"id": "folder"})

        class _Doc:
            # Minimal model double exposing both serialization shapes.
            def __init__(self, doc_id):
                self.id = doc_id

            def to_json(self):
                return {"id": self.id, "name": "doc.txt", "kb_id": "kb1"}

            def to_dict(self):
                return {"id": self.id, "name": "doc.txt", "kb_id": "kb1"}

        monkeypatch.setattr(module.DocumentService, "insert", lambda _doc: _Doc("doc1"))
        monkeypatch.setattr(module.FileService, "add_file_from_kb", lambda *_args, **_kwargs: None)

        async def fake_request_json():
            return {"kb_id": "kb1", "name": "doc.txt"}

        monkeypatch.setattr(module, "get_request_json", fake_request_json)
        res = _run(module.create.__wrapped__())
        assert res["code"] == 0
        assert res["data"]["id"] == "doc1"

View File

@ -13,6 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
from types import SimpleNamespace
import pytest
from common import (
document_change_status,
@ -241,3 +244,170 @@ class TestDocumentMetadataNegative:
res = document_set_meta(WebApiAuth, {"doc_id": doc_id, "meta": "[]"})
assert res["code"] == 101, res
assert "dictionary" in res["message"], res
def _run(coro):
return asyncio.run(coro)
@pytest.mark.p2
class TestDocumentMetadataUnit:
    """Unit tests for document_app metadata endpoints (get_filter, doc_infos,
    metadata_summary, metadata_update) via the isolated module fixture."""

    def _allow_kb(self, module, monkeypatch, kb_id="kb1", tenant_id="tenant1"):
        # Authorize the caller for exactly one dataset: KnowledgebaseService.query
        # is truthy only when queried with this kb_id.
        monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [SimpleNamespace(tenant_id=tenant_id)])
        monkeypatch.setattr(module.KnowledgebaseService, "query", lambda **_kwargs: True if _kwargs.get("id") == kb_id else False)

    def test_filter_missing_kb_id(self, document_app_module, monkeypatch):
        # get_filter() requires a kb_id in the request body.
        module = document_app_module

        async def fake_request_json():
            return {}

        monkeypatch.setattr(module, "get_request_json", fake_request_json)
        res = _run(module.get_filter())
        assert res["code"] == 101
        assert "KB ID" in res["message"]

    def test_filter_unauthorized(self, document_app_module, monkeypatch):
        # None of the caller's tenants own the dataset.
        module = document_app_module
        monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [SimpleNamespace(tenant_id="tenant1")])
        monkeypatch.setattr(module.KnowledgebaseService, "query", lambda **_kwargs: False)

        async def fake_request_json():
            return {"kb_id": "kb1"}

        monkeypatch.setattr(module, "get_request_json", fake_request_json)
        res = _run(module.get_filter())
        assert res["code"] == 103

    def test_filter_invalid_filters(self, document_app_module, monkeypatch):
        # Unknown run-status and type filters are rejected with distinct messages.
        module = document_app_module
        self._allow_kb(module, monkeypatch)

        async def fake_request_json():
            return {"kb_id": "kb1", "run_status": ["INVALID"]}

        monkeypatch.setattr(module, "get_request_json", fake_request_json)
        res = _run(module.get_filter())
        assert res["code"] == 102
        assert "Invalid filter run status" in res["message"]

        async def fake_request_json_types():
            return {"kb_id": "kb1", "types": ["INVALID"]}

        monkeypatch.setattr(module, "get_request_json", fake_request_json_types)
        res = _run(module.get_filter())
        assert res["code"] == 102
        assert "Invalid filter conditions" in res["message"]

    def test_filter_keywords_suffix(self, document_app_module, monkeypatch):
        # keywords/suffix filters reach the service and a filter block is returned.
        module = document_app_module
        self._allow_kb(module, monkeypatch)
        monkeypatch.setattr(module.DocumentService, "get_filter_by_kb_id", lambda *_args, **_kwargs: ({"run": {}}, 1))

        async def fake_request_json():
            return {"kb_id": "kb1", "keywords": "ragflow", "suffix": ["txt"]}

        monkeypatch.setattr(module, "get_request_json", fake_request_json)
        res = _run(module.get_filter())
        assert res["code"] == 0
        assert "filter" in res["data"]

    def test_filter_exception(self, document_app_module, monkeypatch):
        # A service failure surfaces as a generic server error (code 100).
        module = document_app_module
        self._allow_kb(module, monkeypatch)

        def raise_error(*_args, **_kwargs):
            raise RuntimeError("boom")

        monkeypatch.setattr(module.DocumentService, "get_filter_by_kb_id", raise_error)

        async def fake_request_json():
            return {"kb_id": "kb1"}

        monkeypatch.setattr(module, "get_request_json", fake_request_json)
        res = _run(module.get_filter())
        assert res["code"] == 100

    def test_infos_meta_fields(self, document_app_module, monkeypatch):
        # doc_infos() merges per-document metadata into each returned row.
        module = document_app_module
        monkeypatch.setattr(module.DocumentService, "accessible", lambda *_args, **_kwargs: True)

        class _Docs:
            # Mimics a peewee query result exposing .dicts().
            def dicts(self):
                return [{"id": "doc1"}]

        monkeypatch.setattr(module.DocumentService, "get_by_ids", lambda _ids: _Docs())
        monkeypatch.setattr(module.DocMetadataService, "get_document_metadata", lambda _doc_id: {"author": "alice"})

        async def fake_request_json():
            return {"doc_ids": ["doc1"]}

        monkeypatch.setattr(module, "get_request_json", fake_request_json)
        res = _run(module.doc_infos())
        assert res["code"] == 0
        assert res["data"][0]["meta_fields"]["author"] == "alice"

    def test_metadata_summary_missing_kb_id(self, document_app_module, monkeypatch):
        # metadata_summary() requires a kb_id.
        module = document_app_module

        async def fake_request_json():
            return {"doc_ids": ["doc1"]}

        monkeypatch.setattr(module, "get_request_json", fake_request_json)
        res = _run(module.metadata_summary())
        assert res["code"] == 101

    def test_metadata_summary_unauthorized(self, document_app_module, monkeypatch):
        # None of the caller's tenants own the dataset.
        module = document_app_module
        monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [SimpleNamespace(tenant_id="tenant1")])
        monkeypatch.setattr(module.KnowledgebaseService, "query", lambda **_kwargs: False)

        async def fake_request_json():
            return {"kb_id": "kb1", "doc_ids": ["doc1"]}

        monkeypatch.setattr(module, "get_request_json", fake_request_json)
        res = _run(module.metadata_summary())
        assert res["code"] == 103

    def test_metadata_summary_success_and_exception(self, document_app_module, monkeypatch):
        # Happy path returns a summary; service failure returns a server error.
        module = document_app_module
        self._allow_kb(module, monkeypatch)
        monkeypatch.setattr(module.DocMetadataService, "get_metadata_summary", lambda *_args, **_kwargs: {"author": {"alice": 1}})

        async def fake_request_json():
            return {"kb_id": "kb1", "doc_ids": ["doc1"]}

        monkeypatch.setattr(module, "get_request_json", fake_request_json)
        res = _run(module.metadata_summary())
        assert res["code"] == 0
        assert "summary" in res["data"]

        def raise_error(*_args, **_kwargs):
            raise RuntimeError("boom")

        monkeypatch.setattr(module.DocMetadataService, "get_metadata_summary", raise_error)
        res = _run(module.metadata_summary())
        assert res["code"] == 100

    def test_metadata_update_missing_kb_id(self, document_app_module, monkeypatch):
        # metadata_update() requires a kb_id.
        module = document_app_module

        async def fake_request_json():
            return {"doc_ids": ["doc1"], "updates": [], "deletes": []}

        monkeypatch.setattr(module, "get_request_json", fake_request_json)
        # __wrapped__ bypasses the route/login decorators.
        res = _run(module.metadata_update.__wrapped__())
        assert res["code"] == 101
        assert "KB ID" in res["message"]

    def test_metadata_update_success(self, document_app_module, monkeypatch):
        # Happy path reports the number of matched documents.
        module = document_app_module
        monkeypatch.setattr(module.DocMetadataService, "batch_update_metadata", lambda *_args, **_kwargs: 1)

        async def fake_request_json():
            return {"kb_id": "kb1", "doc_ids": ["doc1"], "updates": [{"key": "author", "value": "alice"}], "deletes": []}

        monkeypatch.setattr(module, "get_request_json", fake_request_json)
        res = _run(module.metadata_update.__wrapped__())
        assert res["code"] == 0
        assert res["data"]["matched_docs"] == 1

View File

@ -13,7 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
from concurrent.futures import ThreadPoolExecutor, as_completed
from types import SimpleNamespace
import pytest
from common import list_documents
@ -178,3 +180,214 @@ class TestDocumentsList:
responses = list(as_completed(futures))
assert len(responses) == count, responses
assert all(future.result()["code"] == 0 for future in futures), responses
def _run(coro):
return asyncio.run(coro)
class _DummyArgs(dict):
def get(self, key, default=None):
return super().get(key, default)
@pytest.mark.p2
class TestDocumentsListUnit:
    def _set_args(self, module, monkeypatch, **kwargs):
        # Install a fake flask `request` whose query args are the given kwargs.
        monkeypatch.setattr(module, "request", SimpleNamespace(args=_DummyArgs(kwargs)))
    def _allow_kb(self, module, monkeypatch, kb_id="kb1", tenant_id="tenant1"):
        # Authorize the caller for exactly one dataset: KnowledgebaseService.query
        # is truthy only when queried with this kb_id.
        monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [SimpleNamespace(tenant_id=tenant_id)])
        monkeypatch.setattr(module.KnowledgebaseService, "query", lambda **_kwargs: True if _kwargs.get("id") == kb_id else False)
    def test_missing_kb_id(self, document_app_module, monkeypatch):
        # list_docs() rejects requests without a kb_id query parameter.
        module = document_app_module
        self._set_args(module, monkeypatch)

        async def fake_request_json():
            return {}

        monkeypatch.setattr(module, "get_request_json", fake_request_json)
        res = _run(module.list_docs())
        assert res["code"] == 101
        assert res["message"] == 'Lack of "KB ID"'
    def test_unauthorized_dataset(self, document_app_module, monkeypatch):
        # None of the caller's tenants own the dataset -> authorization error.
        module = document_app_module
        self._set_args(module, monkeypatch, kb_id="kb1")
        monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [SimpleNamespace(tenant_id="tenant1")])
        monkeypatch.setattr(module.KnowledgebaseService, "query", lambda **_kwargs: False)

        async def fake_request_json():
            return {}

        monkeypatch.setattr(module, "get_request_json", fake_request_json)
        res = _run(module.list_docs())
        assert res["code"] == 103
        assert "Only owner of dataset" in res["message"]
    def test_return_empty_metadata_flags(self, document_app_module, monkeypatch):
        # Both spellings of the empty-metadata flag are accepted without error.
        module = document_app_module
        self._set_args(module, monkeypatch, kb_id="kb1")
        self._allow_kb(module, monkeypatch)
        monkeypatch.setattr(module.DocumentService, "get_by_kb_id", lambda *_args, **_kwargs: ([], 0))

        async def fake_request_json():
            return {"return_empty_metadata": "true", "metadata": {"author": "alice"}}

        monkeypatch.setattr(module, "get_request_json", fake_request_json)
        res = _run(module.list_docs())
        assert res["code"] == 0

        async def fake_request_json_empty():
            return {"metadata": {"empty_metadata": True, "author": "alice"}}

        monkeypatch.setattr(module, "get_request_json", fake_request_json_empty)
        res = _run(module.list_docs())
        assert res["code"] == 0
    def test_invalid_filters(self, document_app_module, monkeypatch):
        # Unknown run-status and type filters are rejected with distinct messages.
        module = document_app_module
        self._set_args(module, monkeypatch, kb_id="kb1")
        self._allow_kb(module, monkeypatch)

        async def fake_request_json():
            return {"run_status": ["INVALID"]}

        monkeypatch.setattr(module, "get_request_json", fake_request_json)
        res = _run(module.list_docs())
        assert res["code"] == 102
        assert "Invalid filter run status" in res["message"]

        async def fake_request_json_types():
            return {"types": ["INVALID"]}

        monkeypatch.setattr(module, "get_request_json", fake_request_json_types)
        res = _run(module.list_docs())
        assert res["code"] == 102
        assert "Invalid filter conditions" in res["message"]
    def test_invalid_metadata_types(self, document_app_module, monkeypatch):
        # metadata_condition and metadata must be objects, not strings/lists.
        module = document_app_module
        self._set_args(module, monkeypatch, kb_id="kb1")
        self._allow_kb(module, monkeypatch)

        async def fake_request_json():
            return {"metadata_condition": "bad"}

        monkeypatch.setattr(module, "get_request_json", fake_request_json)
        res = _run(module.list_docs())
        assert res["code"] == 102
        assert "metadata_condition" in res["message"]

        async def fake_request_json_meta():
            return {"metadata": ["not", "object"]}

        monkeypatch.setattr(module, "get_request_json", fake_request_json_meta)
        res = _run(module.list_docs())
        assert res["code"] == 102
        assert "metadata must be an object" in res["message"]
    def test_metadata_condition_empty_result(self, document_app_module, monkeypatch):
        """A metadata_condition that matches no documents short-circuits to an empty page."""
        module = document_app_module
        self._set_args(module, monkeypatch, kb_id="kb1")
        self._allow_kb(module, monkeypatch)
        monkeypatch.setattr(module.DocMetadataService, "get_flatted_meta_by_kbs", lambda *_args, **_kwargs: {})
        # meta_filter yields no matching doc ids, so the handler should return total == 0.
        monkeypatch.setattr(module, "meta_filter", lambda *_args, **_kwargs: set())

        async def fake_request_json():
            return {"metadata_condition": {"conditions": [{"name": "author", "comparison_operator": "is", "value": "alice"}]}}

        monkeypatch.setattr(module, "get_request_json", fake_request_json)
        res = _run(module.list_docs())
        assert res["code"] == 0
        assert res["data"]["total"] == 0
def test_metadata_values_intersection(self, document_app_module, monkeypatch):
module = document_app_module
self._set_args(module, monkeypatch, kb_id="kb1")
self._allow_kb(module, monkeypatch)
metas = {
"author": {"alice": ["doc1", "doc2"]},
"topic": {"rag": ["doc2"]},
}
monkeypatch.setattr(module.DocMetadataService, "get_flatted_meta_by_kbs", lambda *_args, **_kwargs: metas)
captured = {}
def fake_get_by_kb_id(*_args, **_kwargs):
if len(_args) >= 10:
captured["doc_ids_filter"] = _args[9]
else:
captured["doc_ids_filter"] = None
return ([{"id": "doc2", "thumbnail": "", "parser_config": {}}], 1)
monkeypatch.setattr(module.DocumentService, "get_by_kb_id", fake_get_by_kb_id)
async def fake_request_json():
return {"metadata": {"author": ["alice", " ", None], "topic": "rag"}}
monkeypatch.setattr(module, "get_request_json", fake_request_json)
res = _run(module.list_docs())
assert res["code"] == 0
assert captured["doc_ids_filter"] == ["doc2"]
    def test_metadata_intersection_empty(self, document_app_module, monkeypatch):
        """Disjoint per-key doc-id sets intersect to nothing, so the listing is empty."""
        module = document_app_module
        self._set_args(module, monkeypatch, kb_id="kb1")
        self._allow_kb(module, monkeypatch)
        # "author" and "topic" match different documents — intersection is empty.
        metas = {
            "author": {"alice": ["doc1"]},
            "topic": {"rag": ["doc2"]},
        }
        monkeypatch.setattr(module.DocMetadataService, "get_flatted_meta_by_kbs", lambda *_args, **_kwargs: metas)

        async def fake_request_json():
            return {"metadata": {"author": "alice", "topic": "rag"}}

        monkeypatch.setattr(module, "get_request_json", fake_request_json)
        res = _run(module.list_docs())
        assert res["code"] == 0
        assert res["data"]["total"] == 0
    def test_desc_time_and_schema(self, document_app_module, monkeypatch):
        """create_time window filtering is applied and metadata is schema-converted."""
        module = document_app_module
        self._set_args(module, monkeypatch, kb_id="kb1", desc="false", create_time_from="150", create_time_to="250")
        self._allow_kb(module, monkeypatch)
        docs = [
            {"id": "doc1", "thumbnail": "", "parser_config": {"metadata": {"a": 1}}, "create_time": 100},
            {"id": "doc2", "thumbnail": "", "parser_config": {"metadata": {"b": 2}}, "create_time": 200},
        ]

        def fake_get_by_kb_id(*_args, **_kwargs):
            return (docs, 2)

        monkeypatch.setattr(module.DocumentService, "get_by_kb_id", fake_get_by_kb_id)
        # Each document's metadata should be replaced by its JSON-schema form.
        monkeypatch.setattr(module, "turn2jsonschema", lambda _meta: {"schema": True})

        async def fake_request_json():
            return {}

        monkeypatch.setattr(module, "get_request_json", fake_request_json)
        res = _run(module.list_docs())
        assert res["code"] == 0
        # Only doc2 (create_time=200) lies within the [150, 250] window.
        assert len(res["data"]["docs"]) == 1
        assert res["data"]["docs"][0]["parser_config"]["metadata"] == {"schema": True}
    def test_exception_path(self, document_app_module, monkeypatch):
        """Unexpected service-layer exceptions are mapped to the generic error code 100."""
        module = document_app_module
        self._set_args(module, monkeypatch, kb_id="kb1")
        self._allow_kb(module, monkeypatch)

        def raise_error(*_args, **_kwargs):
            raise RuntimeError("boom")

        monkeypatch.setattr(module.DocumentService, "get_by_kb_id", raise_error)

        async def fake_request_json():
            return {}

        monkeypatch.setattr(module, "get_request_json", fake_request_json)
        res = _run(module.list_docs())
        assert res["code"] == 100

View File

@ -13,7 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import string
from types import SimpleNamespace
from concurrent.futures import ThreadPoolExecutor, as_completed
import pytest
@ -21,6 +23,7 @@ from common import list_kbs, upload_documents
from configs import DOCUMENT_NAME_LIMIT, INVALID_API_TOKEN
from libs.auth import RAGFlowWebApiAuth
from utils.file_utils import create_txt_file
from api.constants import FILE_NAME_LEN_LIMIT
@pytest.mark.p1
@ -189,3 +192,288 @@ class TestDocumentsUpload:
res = list_kbs(WebApiAuth)
assert res["data"]["kbs"][0]["doc_num"] == count, res
class _AwaitableValue:
def __init__(self, value):
self._value = value
def __await__(self):
async def _coro():
return self._value
return _coro().__await__()
class _DummyFiles(dict):
def getlist(self, key):
value = self.get(key, [])
if isinstance(value, list):
return value
return [value]
class _DummyFile:
def __init__(self, filename):
self.filename = filename
self.closed = False
self.stream = self
def close(self):
self.closed = True
class _DummyRequest:
    """Stub of Quart's ``request`` exposing awaitable ``form`` and ``files``."""

    def __init__(self, form=None, files=None):
        # Falsy arguments fall back to empty containers, mirroring a bare request.
        self._form_data = form if form else {}
        self._file_data = files if files else _DummyFiles()

    @property
    def form(self):
        # Quart's request.form must be awaited; wrap the stored mapping.
        return _AwaitableValue(self._form_data)

    @property
    def files(self):
        return _AwaitableValue(self._file_data)
def _run(coro):
return asyncio.run(coro)
@pytest.mark.p2
class TestDocumentsUploadUnit:
    """Unit tests for the document upload route driven through stubbed Quart requests."""

    def test_missing_kb_id(self, document_app_module, monkeypatch):
        """An empty kb_id form field is rejected with code 101."""
        module = document_app_module
        monkeypatch.setattr(module, "request", _DummyRequest(form={"kb_id": ""}, files=_DummyFiles()))
        res = _run(module.upload.__wrapped__())
        assert res["code"] == 101
        assert res["message"] == 'Lack of "KB ID"'

    def test_missing_file_part(self, document_app_module, monkeypatch):
        """A request without any file part is rejected with code 101."""
        module = document_app_module
        monkeypatch.setattr(module, "request", _DummyRequest(form={"kb_id": "kb1"}, files=_DummyFiles()))
        res = _run(module.upload.__wrapped__())
        assert res["code"] == 101
        assert res["message"] == "No file part!"

    def test_empty_filename_closes_files(self, document_app_module, monkeypatch):
        """Files with empty names are rejected and the handler closes them."""
        module = document_app_module
        file_obj = _DummyFile("")
        files = _DummyFiles({"file": [file_obj]})
        monkeypatch.setattr(module, "request", _DummyRequest(form={"kb_id": "kb1"}, files=files))
        res = _run(module.upload.__wrapped__())
        assert res["code"] == 101
        assert res["message"] == "No file selected!"
        # The rejected file must have been closed by the handler.
        assert file_obj.closed is True

    def test_filename_too_long(self, document_app_module, monkeypatch):
        """Names longer than FILE_NAME_LEN_LIMIT bytes are rejected with code 101."""
        module = document_app_module
        long_name = "a" * (FILE_NAME_LEN_LIMIT + 1)
        file_obj = _DummyFile(long_name)
        files = _DummyFiles({"file": [file_obj]})
        monkeypatch.setattr(module, "request", _DummyRequest(form={"kb_id": "kb1"}, files=files))
        res = _run(module.upload.__wrapped__())
        assert res["code"] == 101
        assert res["message"] == f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less."

    def test_invalid_kb_id_raises(self, document_app_module, monkeypatch):
        """An unknown dataset id surfaces as a LookupError from the handler."""
        module = document_app_module
        file_obj = _DummyFile("ragflow_test.txt")
        files = _DummyFiles({"file": [file_obj]})
        monkeypatch.setattr(module, "request", _DummyRequest(form={"kb_id": "missing"}, files=files))
        monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (False, None))
        with pytest.raises(LookupError):
            _run(module.upload.__wrapped__())

    def test_no_permission(self, document_app_module, monkeypatch):
        """Failing the team-permission check yields code 109."""
        module = document_app_module
        kb = SimpleNamespace(id="kb1", tenant_id="tenant1", name="kb", parser_id="parser", parser_config={})
        monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, kb))
        monkeypatch.setattr(module, "check_kb_team_permission", lambda *_args, **_kwargs: False)
        file_obj = _DummyFile("ragflow_test.txt")
        files = _DummyFiles({"file": [file_obj]})
        monkeypatch.setattr(module, "request", _DummyRequest(form={"kb_id": "kb1"}, files=files))
        res = _run(module.upload.__wrapped__())
        assert res["code"] == 109
        assert res["message"] == "No authorization."

    def test_thread_pool_errors(self, document_app_module, monkeypatch):
        """Per-file errors reported by the worker pool produce a 500 with details."""
        module = document_app_module
        kb = SimpleNamespace(id="kb1", tenant_id="tenant1", name="kb", parser_id="parser", parser_config={})
        monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, kb))
        monkeypatch.setattr(module, "check_kb_team_permission", lambda *_args, **_kwargs: True)

        # Simulate the pool reporting one error alongside a processed file.
        async def fake_thread_pool_exec(*_args, **_kwargs):
            return (["unsupported type"], [("file1", "blob")])

        monkeypatch.setattr(module, "thread_pool_exec", fake_thread_pool_exec)
        file_obj = _DummyFile("ragflow_test.txt")
        files = _DummyFiles({"file": [file_obj]})
        monkeypatch.setattr(module, "request", _DummyRequest(form={"kb_id": "kb1"}, files=files))
        res = _run(module.upload.__wrapped__())
        assert res["code"] == 500
        assert "unsupported type" in res["message"]
        assert res["data"] == ["file1"]

    def test_empty_upload_result(self, document_app_module, monkeypatch):
        """An empty result from the worker pool maps to a file-format error (102)."""
        module = document_app_module
        kb = SimpleNamespace(id="kb1", tenant_id="tenant1", name="kb", parser_id="parser", parser_config={})
        monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, kb))
        monkeypatch.setattr(module, "check_kb_team_permission", lambda *_args, **_kwargs: True)

        async def fake_thread_pool_exec(*_args, **_kwargs):
            return (None, [])

        monkeypatch.setattr(module, "thread_pool_exec", fake_thread_pool_exec)
        file_obj = _DummyFile("ragflow_test.txt")
        files = _DummyFiles({"file": [file_obj]})
        monkeypatch.setattr(module, "request", _DummyRequest(form={"kb_id": "kb1"}, files=files))
        res = _run(module.upload.__wrapped__())
        assert res["code"] == 102
        assert "file format" in res["message"]
@pytest.mark.p2
class TestWebCrawlUnit:
    """Unit tests for the web_crawl route driven through stubbed Quart requests."""

    def test_missing_kb_id(self, document_app_module, monkeypatch):
        """An empty kb_id form field is rejected with code 101."""
        module = document_app_module
        monkeypatch.setattr(module, "request", _DummyRequest(form={"kb_id": "", "name": "doc", "url": "http://example.com"}))
        res = _run(module.web_crawl.__wrapped__())
        assert res["code"] == 101
        assert res["message"] == 'Lack of "KB ID"'

    def test_invalid_url(self, document_app_module, monkeypatch):
        """A malformed URL is rejected with code 101."""
        module = document_app_module
        monkeypatch.setattr(module, "request", _DummyRequest(form={"kb_id": "kb1", "name": "doc", "url": "not-a-url"}))
        res = _run(module.web_crawl.__wrapped__())
        assert res["code"] == 101
        assert res["message"] == "The URL format is invalid"

    def test_invalid_kb_id_raises(self, document_app_module, monkeypatch):
        """An unknown dataset id surfaces as a LookupError from the handler."""
        module = document_app_module
        monkeypatch.setattr(module, "is_valid_url", lambda _url: True)
        monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (False, None))
        monkeypatch.setattr(module, "request", _DummyRequest(form={"kb_id": "missing", "name": "doc", "url": "http://example.com"}))
        with pytest.raises(LookupError):
            _run(module.web_crawl.__wrapped__())

    def test_no_permission(self, document_app_module, monkeypatch):
        """Failing the team-permission check yields code 109."""
        module = document_app_module
        kb = SimpleNamespace(id="kb1", tenant_id="tenant1", name="kb", parser_id="parser", parser_config={})
        monkeypatch.setattr(module, "is_valid_url", lambda _url: True)
        monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, kb))
        monkeypatch.setattr(module, "check_kb_team_permission", lambda *_args, **_kwargs: False)
        monkeypatch.setattr(module, "request", _DummyRequest(form={"kb_id": "kb1", "name": "doc", "url": "http://example.com"}))
        res = _run(module.web_crawl.__wrapped__())
        assert res["code"] == 109
        assert res["message"] == "No authorization."

    def test_download_failure(self, document_app_module, monkeypatch):
        """html2pdf returning nothing maps to a download-failure error (code 100)."""
        module = document_app_module
        kb = SimpleNamespace(id="kb1", tenant_id="tenant1", name="kb", parser_id="parser", parser_config={})
        monkeypatch.setattr(module, "is_valid_url", lambda _url: True)
        monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, kb))
        monkeypatch.setattr(module, "check_kb_team_permission", lambda *_args, **_kwargs: True)
        monkeypatch.setattr(module, "html2pdf", lambda _url: None)
        monkeypatch.setattr(module, "request", _DummyRequest(form={"kb_id": "kb1", "name": "doc", "url": "http://example.com"}))
        res = _run(module.web_crawl.__wrapped__())
        assert res["code"] == 100
        assert "Download failure" in res["message"]

    def test_unsupported_type(self, document_app_module, monkeypatch):
        """A crawled file whose resolved name has an unsupported extension fails (100)."""
        module = document_app_module
        kb = SimpleNamespace(id="kb1", tenant_id="tenant1", name="kb", parser_id="parser", parser_config={})
        monkeypatch.setattr(module, "is_valid_url", lambda _url: True)
        monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, kb))
        monkeypatch.setattr(module, "check_kb_team_permission", lambda *_args, **_kwargs: True)
        monkeypatch.setattr(module, "html2pdf", lambda _url: b"%PDF-1.4")
        monkeypatch.setattr(module.FileService, "get_root_folder", lambda _uid: {"id": "root"})
        monkeypatch.setattr(module.FileService, "init_knowledgebase_docs", lambda *_args, **_kwargs: None)
        monkeypatch.setattr(module.FileService, "get_kb_folder", lambda *_args, **_kwargs: {"id": "kb_root"})
        monkeypatch.setattr(module.FileService, "new_a_file_from_kb", lambda *_args, **_kwargs: {"id": "kb_folder"})
        # Force the deduplicated name to an unsupported .exe extension.
        monkeypatch.setattr(module, "duplicate_name", lambda *_args, **_kwargs: "bad.exe")
        monkeypatch.setattr(module, "request", _DummyRequest(form={"kb_id": "kb1", "name": "doc", "url": "http://example.com"}))
        res = _run(module.web_crawl.__wrapped__())
        assert res["code"] == 100
        assert "supported yet" in res["message"]

    @pytest.mark.parametrize(
        "filename,filetype,expected_parser",
        [
            ("image.png", "visual", "picture"),
            ("sound.mp3", "aural", "audio"),
            ("deck.pptx", "doc", "presentation"),
            ("mail.eml", "doc", "email"),
        ],
    )
    def test_success_parser_overrides(self, document_app_module, monkeypatch, filename, filetype, expected_parser):
        """Each detected file type selects the matching parser_id and stores the blob."""
        module = document_app_module
        kb = SimpleNamespace(id="kb1", tenant_id="tenant1", name="kb", parser_id="parser", parser_config={})
        captured = {}

        class _Storage:
            # Storage stub recording that the crawled blob was written.
            def obj_exist(self, *_args, **_kwargs):
                return False

            def put(self, *_args, **_kwargs):
                captured["put"] = True

        def insert_doc(doc):
            # Capture the inserted document so parser selection can be asserted.
            captured["doc"] = doc

        monkeypatch.setattr(module, "is_valid_url", lambda _url: True)
        monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, kb))
        monkeypatch.setattr(module, "check_kb_team_permission", lambda *_args, **_kwargs: True)
        monkeypatch.setattr(module, "html2pdf", lambda _url: b"%PDF-1.4")
        monkeypatch.setattr(module.FileService, "get_root_folder", lambda _uid: {"id": "root"})
        monkeypatch.setattr(module.FileService, "init_knowledgebase_docs", lambda *_args, **_kwargs: None)
        monkeypatch.setattr(module.FileService, "get_kb_folder", lambda *_args, **_kwargs: {"id": "kb_root"})
        monkeypatch.setattr(module.FileService, "new_a_file_from_kb", lambda *_args, **_kwargs: {"id": "kb_folder"})
        monkeypatch.setattr(module, "duplicate_name", lambda *_args, **_kwargs: filename)
        monkeypatch.setattr(module, "filename_type", lambda _name: filetype)
        monkeypatch.setattr(module, "thumbnail", lambda *_args, **_kwargs: "")
        monkeypatch.setattr(module, "get_uuid", lambda: "doc-1")
        monkeypatch.setattr(module.settings, "STORAGE_IMPL", _Storage())
        monkeypatch.setattr(module.DocumentService, "insert", insert_doc)
        monkeypatch.setattr(module.FileService, "add_file_from_kb", lambda *_args, **_kwargs: None)
        monkeypatch.setattr(module, "request", _DummyRequest(form={"kb_id": "kb1", "name": "doc", "url": "http://example.com"}))
        res = _run(module.web_crawl.__wrapped__())
        assert res["code"] == 0
        assert captured["doc"]["parser_id"] == expected_parser
        assert captured["put"] is True

    def test_exception_path(self, document_app_module, monkeypatch):
        """A DocumentService.insert failure is converted to the generic error (100)."""
        module = document_app_module
        kb = SimpleNamespace(id="kb1", tenant_id="tenant1", name="kb", parser_id="parser", parser_config={})

        class _Storage:
            def obj_exist(self, *_args, **_kwargs):
                return False

            def put(self, *_args, **_kwargs):
                return None

        def insert_doc(_doc):
            raise RuntimeError("boom")

        monkeypatch.setattr(module, "is_valid_url", lambda _url: True)
        monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, kb))
        monkeypatch.setattr(module, "check_kb_team_permission", lambda *_args, **_kwargs: True)
        monkeypatch.setattr(module, "html2pdf", lambda _url: b"%PDF-1.4")
        monkeypatch.setattr(module.FileService, "get_root_folder", lambda _uid: {"id": "root"})
        monkeypatch.setattr(module.FileService, "init_knowledgebase_docs", lambda *_args, **_kwargs: None)
        monkeypatch.setattr(module.FileService, "get_kb_folder", lambda *_args, **_kwargs: {"id": "kb_root"})
        monkeypatch.setattr(module.FileService, "new_a_file_from_kb", lambda *_args, **_kwargs: {"id": "kb_folder"})
        monkeypatch.setattr(module, "duplicate_name", lambda *_args, **_kwargs: "doc.pdf")
        monkeypatch.setattr(module, "filename_type", lambda _name: "pdf")
        monkeypatch.setattr(module, "thumbnail", lambda *_args, **_kwargs: "")
        monkeypatch.setattr(module, "get_uuid", lambda: "doc-1")
        monkeypatch.setattr(module.settings, "STORAGE_IMPL", _Storage())
        monkeypatch.setattr(module.DocumentService, "insert", insert_doc)
        monkeypatch.setattr(module.FileService, "add_file_from_kb", lambda *_args, **_kwargs: None)
        monkeypatch.setattr(module, "request", _DummyRequest(form={"kb_id": "kb1", "name": "doc", "url": "http://example.com"}))
        res = _run(module.web_crawl.__wrapped__())
        assert res["code"] == 100

View File

@ -206,3 +206,25 @@ class TestKbPipelineLogs:
res = kb_delete_pipeline_logs(WebApiAuth, params={"kb_id": kb_id}, payload={"log_ids": []})
assert res["code"] == 0, res
assert res["data"] is True, res
    @pytest.mark.p3
    def test_list_pipeline_logs_missing_kb_id(self, WebApiAuth):
        """Omitting kb_id from the query params yields the argument-missing code 101."""
        res = kb_list_pipeline_logs(WebApiAuth, params={}, payload={})
        assert res["code"] == 101, res
        assert "KB ID" in res["message"], res
    @pytest.mark.p3
    def test_list_pipeline_logs_abnormal_date_filter(self, WebApiAuth, add_document):
        """This create-date window combined with desc=false is rejected as abnormal (102)."""
        kb_id, _ = add_document
        res = kb_list_pipeline_logs(
            WebApiAuth,
            params={
                "kb_id": kb_id,
                "desc": "false",
                "create_date_from": "2025-01-01",
                "create_date_to": "2025-02-01",
            },
            payload={},
        )
        assert res["code"] == 102, res
        assert "Create data filter is abnormal." in res["message"], res

File diff suppressed because it is too large Load Diff

View File

@ -17,9 +17,11 @@ import uuid
import pytest
from common import (
delete_knowledge_graph,
kb_basic_info,
kb_get_meta,
kb_update_metadata_setting,
knowledge_graph,
list_tags,
list_tags_from_kbs,
rename_tags,
@ -121,6 +123,20 @@ class TestAuthorization:
assert res["code"] == expected_code, res
assert expected_fragment in res["message"], res
    @pytest.mark.p2
    @pytest.mark.parametrize("invalid_auth, expected_code, expected_fragment", INVALID_AUTH_CASES)
    def test_knowledge_graph_auth_invalid(self, invalid_auth, expected_code, expected_fragment):
        """Fetching the knowledge graph with bad credentials fails with the auth error."""
        res = knowledge_graph(invalid_auth, "kb_id")
        assert res["code"] == expected_code, res
        assert expected_fragment in res["message"], res
    @pytest.mark.p2
    @pytest.mark.parametrize("invalid_auth, expected_code, expected_fragment", INVALID_AUTH_CASES)
    def test_delete_knowledge_graph_auth_invalid(self, invalid_auth, expected_code, expected_fragment):
        """Deleting the knowledge graph with bad credentials fails with the auth error."""
        res = delete_knowledge_graph(invalid_auth, "kb_id")
        assert res["code"] == expected_code, res
        assert expected_fragment in res["message"], res
class TestKbTagsMeta:
@pytest.mark.p2
@ -205,6 +221,22 @@ class TestKbTagsMeta:
assert res["data"]["id"] == kb_id, res
assert res["data"]["parser_config"]["metadata"] == metadata, res
    @pytest.mark.p2
    def test_knowledge_graph(self, WebApiAuth, add_dataset):
        """The knowledge-graph endpoint returns graph and mind_map sections."""
        kb_id = add_dataset
        res = knowledge_graph(WebApiAuth, kb_id)
        assert res["code"] == 0, res
        assert isinstance(res["data"], dict), res
        assert "graph" in res["data"], res
        assert "mind_map" in res["data"], res
    @pytest.mark.p2
    def test_delete_knowledge_graph(self, WebApiAuth, add_dataset):
        """Deleting the knowledge graph of an owned dataset succeeds."""
        kb_id = add_dataset
        res = delete_knowledge_graph(WebApiAuth, kb_id)
        assert res["code"] == 0, res
        assert res["data"] is True, res
class TestKbTagsMetaNegative:
@pytest.mark.p3
@ -249,3 +281,15 @@ class TestKbTagsMetaNegative:
assert res["code"] == 101, res
assert "required argument are missing" in res["message"], res
assert "metadata" in res["message"], res
    @pytest.mark.p3
    def test_knowledge_graph_invalid_kb(self, WebApiAuth):
        """Querying the graph of a dataset the caller does not own yields code 109."""
        res = knowledge_graph(WebApiAuth, "invalid_kb_id")
        assert res["code"] == 109, res
        assert "No authorization" in res["message"], res
    @pytest.mark.p3
    def test_delete_knowledge_graph_invalid_kb(self, WebApiAuth):
        """Deleting the graph of a dataset the caller does not own yields code 109."""
        res = delete_knowledge_graph(WebApiAuth, "invalid_kb_id")
        assert res["code"] == 109, res
        assert "No authorization" in res["message"], res

View File

@ -182,3 +182,20 @@ class TestDatasetsList:
res = list_kbs(WebApiAuth, params)
assert res["code"] == 0, res
assert len(res["data"]["kbs"]) == expected_page_size, res
    @pytest.mark.p2
    def test_owner_ids_payload_mode(self, WebApiAuth):
        """Filtering by owner_ids in the JSON payload restricts results to that tenant."""
        # Discover a real tenant id from an unfiltered listing first.
        base_res = list_kbs(WebApiAuth, {"page_size": 10})
        assert base_res["code"] == 0, base_res
        assert base_res["data"]["kbs"], base_res
        owner_id = base_res["data"]["kbs"][0]["tenant_id"]
        res = list_kbs(
            WebApiAuth,
            params={"page": 1, "page_size": 2, "desc": "false"},
            payload={"owner_ids": [owner_id]},
        )
        assert res["code"] == 0, res
        assert res["data"]["total"] >= len(res["data"]["kbs"]), res
        assert len(res["data"]["kbs"]) <= 2, res
        # Every returned dataset must belong to the requested owner.
        assert all(kb["tenant_id"] == owner_id for kb in res["data"]["kbs"]), res

View File

@ -0,0 +1,290 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import importlib.util
import sys
from pathlib import Path
from types import ModuleType, SimpleNamespace
import pytest
class _DummyManager:
def route(self, *_args, **_kwargs):
def decorator(func):
return func
return decorator
class _ExprField:
def __init__(self, name):
self.name = name
def __eq__(self, other):
return (self.name, other)
class _DummyTenantLLMModel:
    """TenantLLM model stub whose fields build inspectable query expressions."""

    # Class-level expression fields used when the app builds ORM filters.
    tenant_id = _ExprField("tenant_id")
    llm_factory = _ExprField("llm_factory")
class _TenantLLMRow:
def __init__(self, *, llm_name, llm_factory, model_type, api_key="key", status="1"):
self.llm_name = llm_name
self.llm_factory = llm_factory
self.model_type = model_type
self.api_key = api_key
self.status = status
def to_dict(self):
return {
"llm_name": self.llm_name,
"llm_factory": self.llm_factory,
"model_type": self.model_type,
"status": self.status,
}
class _LLMRow:
def __init__(self, *, llm_name, fid, model_type, status="1"):
self.llm_name = llm_name
self.fid = fid
self.model_type = model_type
self.status = status
def to_dict(self):
return {
"llm_name": self.llm_name,
"fid": self.fid,
"model_type": self.model_type,
"status": self.status,
}
def _run(coro):
return asyncio.run(coro)
def _load_llm_app(monkeypatch):
    """Import ``api/apps/llm_app.py`` in isolation with all of its imports stubbed.

    Every dependency module is planted in ``sys.modules`` before the file is
    executed, so real services, DB models and LLM suppliers are never touched.
    Returns the freshly-executed module object exposing the route handlers.
    """
    repo_root = Path(__file__).resolve().parents[4]
    # Stub quart so ``from quart import request`` works without the framework.
    quart_mod = ModuleType("quart")
    quart_mod.request = SimpleNamespace(args={})
    monkeypatch.setitem(sys.modules, "quart", quart_mod)
    # Stub the apps package: login_required becomes a pass-through and
    # current_user a fixed tenant.
    apps_mod = ModuleType("api.apps")
    apps_mod.__path__ = [str(repo_root / "api" / "apps")]
    apps_mod.login_required = lambda fn: fn
    apps_mod.current_user = SimpleNamespace(id="tenant-1")
    monkeypatch.setitem(sys.modules, "api.apps", apps_mod)
    # Stub the tenant LLM service layer with empty-result/no-op methods.
    tenant_llm_mod = ModuleType("api.db.services.tenant_llm_service")

    class _StubLLMFactoriesService:
        @staticmethod
        def query(**_kwargs):
            return []

    class _StubTenantLLMService:
        @staticmethod
        def ensure_mineru_from_env(_tenant_id):
            return None

        @staticmethod
        def query(**_kwargs):
            return []

        @staticmethod
        def get_my_llms(_tenant_id):
            return []

        @staticmethod
        def save(**_kwargs):
            return True

        @staticmethod
        def filter_delete(_filters):
            return True

    tenant_llm_mod.LLMFactoriesService = _StubLLMFactoriesService
    tenant_llm_mod.TenantLLMService = _StubTenantLLMService
    monkeypatch.setitem(sys.modules, "api.db.services.tenant_llm_service", tenant_llm_mod)
    # Stub the global LLM catalogue service.
    llm_service_mod = ModuleType("api.db.services.llm_service")

    class _StubLLMService:
        @staticmethod
        def get_all():
            return []

        @staticmethod
        def query(**_kwargs):
            return []

    llm_service_mod.LLMService = _StubLLMService
    monkeypatch.setitem(sys.modules, "api.db.services.llm_service", llm_service_mod)
    # Stub the API helpers: result builders return plain dicts the tests can assert on.
    api_utils_mod = ModuleType("api.utils.api_utils")
    api_utils_mod.get_allowed_llm_factories = lambda: []
    api_utils_mod.get_data_error_result = lambda message="", code=400, data=None: {
        "code": code,
        "message": message,
        "data": data,
    }
    api_utils_mod.get_json_result = lambda data=None, message="", code=0: {
        "code": code,
        "message": message,
        "data": data,
    }

    async def _get_request_json():
        return {}

    api_utils_mod.get_request_json = _get_request_json
    api_utils_mod.server_error_response = lambda exc: {"code": 500, "message": str(exc), "data": None}
    api_utils_mod.validate_request = lambda *_args, **_kwargs: (lambda fn: fn)
    monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod)
    # Stub shared constants (status and model-type enums).
    constants_mod = ModuleType("common.constants")
    constants_mod.StatusEnum = SimpleNamespace(VALID=SimpleNamespace(value="1"), INVALID=SimpleNamespace(value="0"))
    constants_mod.LLMType = SimpleNamespace(
        CHAT="chat",
        EMBEDDING="embedding",
        SPEECH2TEXT="speech2text",
        IMAGE2TEXT="image2text",
        RERANK="rerank",
        TTS="tts",
        OCR="ocr",
    )
    monkeypatch.setitem(sys.modules, "common.constants", constants_mod)
    # Stub the ORM model and image/model-supplier registries.
    db_models_mod = ModuleType("api.db.db_models")
    db_models_mod.TenantLLM = _DummyTenantLLMModel
    monkeypatch.setitem(sys.modules, "api.db.db_models", db_models_mod)
    base64_mod = ModuleType("rag.utils.base64_image")
    base64_mod.test_image = lambda _s: _s
    monkeypatch.setitem(sys.modules, "rag.utils.base64_image", base64_mod)
    rag_llm_mod = ModuleType("rag.llm")
    rag_llm_mod.EmbeddingModel = {}
    rag_llm_mod.ChatModel = {}
    rag_llm_mod.RerankModel = {}
    rag_llm_mod.CvModel = {}
    rag_llm_mod.TTSModel = {}
    rag_llm_mod.OcrModel = {}
    rag_llm_mod.Seq2txtModel = {}
    monkeypatch.setitem(sys.modules, "rag.llm", rag_llm_mod)
    # Execute llm_app.py under a private module name with a no-op route manager
    # injected beforehand so its decorators register nothing.
    module_path = repo_root / "api" / "apps" / "llm_app.py"
    spec = importlib.util.spec_from_file_location("test_llm_list_unit_module", module_path)
    module = importlib.util.module_from_spec(spec)
    module.manager = _DummyManager()
    spec.loader.exec_module(module)
    return module
@pytest.mark.p2
def test_list_app_grouping_availability_and_merge(monkeypatch):
    """list_app groups models by factory, merges tenant models and marks availability."""
    module = _load_llm_app(monkeypatch)
    ensure_calls = []
    # Record that the MinerU env bootstrap runs for the current tenant.
    monkeypatch.setattr(module.TenantLLMService, "ensure_mineru_from_env", lambda tenant_id: ensure_calls.append(tenant_id))
    tenant_rows = [
        _TenantLLMRow(llm_name="fast-emb", llm_factory="FastEmbed", model_type="embedding", api_key="k1", status="1"),
        _TenantLLMRow(llm_name="tenant-only", llm_factory="CustomFactory", model_type="chat", api_key="k2", status="1"),
    ]
    monkeypatch.setattr(module.TenantLLMService, "query", lambda **_kwargs: tenant_rows)
    all_llms = [
        _LLMRow(llm_name="tei-embed", fid="Builtin", model_type="embedding", status="1"),
        _LLMRow(llm_name="fast-emb", fid="FastEmbed", model_type="embedding", status="1"),
        _LLMRow(llm_name="not-in-status", fid="Other", model_type="chat", status="1"),
    ]
    monkeypatch.setattr(module.LLMService, "get_all", lambda: all_llms)
    monkeypatch.setattr(module, "request", SimpleNamespace(args={}))
    # The TEI profile env vars should make the builtin embedding model available.
    monkeypatch.setenv("COMPOSE_PROFILES", "tei-cpu")
    monkeypatch.setenv("TEI_MODEL", "tei-embed")
    res = _run(module.list_app())
    assert res["code"] == 0
    assert ensure_calls == ["tenant-1"]
    data = res["data"]
    assert {"Builtin", "FastEmbed", "CustomFactory"}.issubset(set(data.keys()))
    builtin = data["Builtin"][0]
    assert builtin["llm_name"] == "tei-embed"
    assert builtin["available"] is True
    fastembed = data["FastEmbed"][0]
    assert fastembed["llm_name"] == "fast-emb"
    assert fastembed["available"] is True
    tenant_only = data["CustomFactory"][0]
    assert tenant_only["llm_name"] == "tenant-only"
    assert tenant_only["available"] is True
@pytest.mark.p2
def test_list_app_model_type_filter(monkeypatch):
    """The model_type query argument restricts the listing to matching factories."""
    module = _load_llm_app(monkeypatch)
    monkeypatch.setattr(module.TenantLLMService, "ensure_mineru_from_env", lambda _tenant_id: None)
    monkeypatch.setattr(
        module.TenantLLMService,
        "query",
        lambda **_kwargs: [
            _TenantLLMRow(llm_name="fast-emb", llm_factory="FastEmbed", model_type="embedding", api_key="k1", status="1"),
            _TenantLLMRow(llm_name="tenant-only", llm_factory="CustomFactory", model_type="chat", api_key="k2", status="1"),
        ],
    )
    monkeypatch.setattr(
        module.LLMService,
        "get_all",
        lambda: [
            _LLMRow(llm_name="tei-embed", fid="Builtin", model_type="embedding", status="1"),
            _LLMRow(llm_name="fast-emb", fid="FastEmbed", model_type="embedding", status="1"),
        ],
    )
    # Only chat models should survive the filter.
    monkeypatch.setattr(module, "request", SimpleNamespace(args={"model_type": "chat"}))
    res = _run(module.list_app())
    assert res["code"] == 0
    assert list(res["data"].keys()) == ["CustomFactory"]
    assert res["data"]["CustomFactory"][0]["model_type"] == "chat"
@pytest.mark.p2
def test_list_app_exception_path(monkeypatch):
    """Service-layer exceptions are converted into a 500 server-error response."""
    module = _load_llm_app(monkeypatch)
    monkeypatch.setattr(module, "request", SimpleNamespace(args={}))
    monkeypatch.setattr(module.TenantLLMService, "ensure_mineru_from_env", lambda _tenant_id: None)
    # The generator .throw trick makes the lambda raise when query() is called.
    monkeypatch.setattr(
        module.TenantLLMService,
        "query",
        lambda **_kwargs: (_ for _ in ()).throw(RuntimeError("query boom")),
    )
    res = _run(module.list_app())
    assert res["code"] == 500
    assert "query boom" in res["message"]

View File

@ -0,0 +1,708 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import importlib.util
import inspect
import sys
from functools import wraps
from pathlib import Path
from types import ModuleType, SimpleNamespace
import pytest
class _DummyManager:
def route(self, *_args, **_kwargs):
def decorator(func):
return func
return decorator
class _Field:
def __init__(self, name):
self.name = name
def __eq__(self, other):
return (self.name, other)
class _DummyMCPServer:
    """MCPServer model stub: class-level fields act as query expressions, instances as rows."""

    # Class attributes are _Field expression builders; instance attributes
    # below shadow them with plain row values.
    id = _Field("id")
    tenant_id = _Field("tenant_id")

    def __init__(self, **kwargs):
        self.id = kwargs.get("id", "")
        self.name = kwargs.get("name", "")
        self.url = kwargs.get("url", "")
        self.server_type = kwargs.get("server_type", "sse")
        self.tenant_id = kwargs.get("tenant_id", "tenant_1")
        self.variables = kwargs.get("variables", {})
        self.headers = kwargs.get("headers", {})

    def to_dict(self):
        # Serialization mirroring the ORM row's to_dict().
        return {
            "id": self.id,
            "name": self.name,
            "url": self.url,
            "server_type": self.server_type,
            "tenant_id": self.tenant_id,
            "variables": self.variables,
            "headers": self.headers,
        }
class _DummyMCPServerService:
@staticmethod
def get_servers(*_args, **_kwargs):
return []
@staticmethod
def get_or_none(*_args, **_kwargs):
return None
@staticmethod
def get_by_id(*_args, **_kwargs):
return False, None
@staticmethod
def get_by_name_and_tenant(*_args, **_kwargs):
return False, None
@staticmethod
def insert(**_kwargs):
return True
@staticmethod
def filter_update(*_args, **_kwargs):
return True
@staticmethod
def delete_by_ids(*_args, **_kwargs):
return True
class _DummyTenantService:
@staticmethod
def get_by_id(*_args, **_kwargs):
return True, SimpleNamespace(id="tenant_1")
class _DummyTool:
def __init__(self, name):
self._name = name
def model_dump(self):
return {"name": self._name}
class _DummyMCPToolCallSession:
    """MCPToolCallSession stub returning two fixed tools and a canned call result."""

    def __init__(self, _mcp_server, _variables):
        # Constructor arguments are accepted but ignored.
        self._tools = [_DummyTool("tool_a"), _DummyTool("tool_b")]

    def get_tools(self, _timeout):
        return self._tools

    def tool_call(self, _name, _arguments, _timeout):
        # Every tool invocation succeeds with a constant payload.
        return "ok"
def _run(coro):
return asyncio.run(coro)
def _set_request_json(monkeypatch, module, payload):
    """Patch the loaded module's ``get_request_json`` to return *payload*."""

    async def _request_json():
        return payload

    monkeypatch.setattr(module, "get_request_json", _request_json)
def _load_mcp_server_app(monkeypatch):
    """Load ``api/apps/mcp_server_app.py`` with all external modules stubbed.

    Every dependency the target imports (``common``, ``api.apps``, the db
    model/service modules, the MCP tool-call connector and
    ``api.utils.api_utils``) is installed into ``sys.modules`` *before* the
    file is executed, so its top-level imports resolve against the stubs.
    Returns the freshly executed module object.
    """
    repo_root = Path(__file__).resolve().parents[4]
    # ``common`` is exposed as a package rooted in the repo so submodule
    # imports still resolve against real files where needed.
    common_pkg = ModuleType("common")
    common_pkg.__path__ = [str(repo_root / "common")]
    monkeypatch.setitem(sys.modules, "common", common_pkg)
    # Fake logged-in user plus a pass-through login_required decorator.
    apps_mod = ModuleType("api.apps")
    apps_mod.current_user = SimpleNamespace(id="tenant_1")
    apps_mod.login_required = lambda func: func
    monkeypatch.setitem(sys.modules, "api.apps", apps_mod)
    db_models_mod = ModuleType("api.db.db_models")
    db_models_mod.MCPServer = _DummyMCPServer
    monkeypatch.setitem(sys.modules, "api.db.db_models", db_models_mod)
    mcp_service_mod = ModuleType("api.db.services.mcp_server_service")
    mcp_service_mod.MCPServerService = _DummyMCPServerService
    monkeypatch.setitem(sys.modules, "api.db.services.mcp_server_service", mcp_service_mod)
    user_service_mod = ModuleType("api.db.services.user_service")
    user_service_mod.TenantService = _DummyTenantService
    monkeypatch.setitem(sys.modules, "api.db.services.user_service", user_service_mod)
    mcp_conn_mod = ModuleType("common.mcp_tool_call_conn")
    mcp_conn_mod.MCPToolCallSession = _DummyMCPToolCallSession
    mcp_conn_mod.close_multiple_mcp_toolcall_sessions = lambda _sessions: None
    monkeypatch.setitem(sys.modules, "common.mcp_tool_call_conn", mcp_conn_mod)
    # Minimal api_utils surface: JSON result helpers plus a validate_request
    # decorator that supports both sync and async handlers.
    api_utils_mod = ModuleType("api.utils.api_utils")

    async def _default_request_json():
        return {}

    def _get_json_result(code=0, message="success", data=None):
        return {"code": code, "message": message, "data": data}

    def _get_data_error_result(code=102, message="Sorry! Data missing!"):
        return {"code": code, "message": message}

    def _server_error_response(error):
        # Mirror the real helper's shape: code 100 with the repr of the error.
        return {"code": 100, "message": repr(error)}

    async def _get_mcp_tools(*_args, **_kwargs):
        return {}

    def _validate_request(*_args, **_kwargs):
        def _decorator(func):
            @wraps(func)
            async def _wrapped(*func_args, **func_kwargs):
                if inspect.iscoroutinefunction(func):
                    return await func(*func_args, **func_kwargs)
                return func(*func_args, **func_kwargs)

            return _wrapped

        return _decorator

    api_utils_mod.get_request_json = _default_request_json
    api_utils_mod.get_json_result = _get_json_result
    api_utils_mod.get_data_error_result = _get_data_error_result
    api_utils_mod.server_error_response = _server_error_response
    api_utils_mod.validate_request = _validate_request
    api_utils_mod.get_mcp_tools = _get_mcp_tools
    monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod)
    # Execute the target file under a throwaway module name with a dummy
    # blueprint manager injected so @manager.route decorators are no-ops.
    module_name = "test_mcp_server_app_unit_module"
    module_path = repo_root / "api" / "apps" / "mcp_server_app.py"
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    module = importlib.util.module_from_spec(spec)
    module.manager = _DummyManager()
    monkeypatch.setitem(sys.modules, module_name, module)
    spec.loader.exec_module(module)
    return module
@pytest.mark.p2
def test_list_mcp_desc_pagination_and_exception(monkeypatch):
    """list_mcp: pagination slicing with desc=false, then a service error -> code 100."""
    module = _load_mcp_server_app(monkeypatch)
    # Page 2 with page_size 1 over two servers should yield only the second.
    monkeypatch.setattr(
        module,
        "request",
        SimpleNamespace(args={"keywords": "k", "page": "2", "page_size": "1", "orderby": "create_time", "desc": "false"}),
    )
    _set_request_json(monkeypatch, module, {"mcp_ids": []})
    monkeypatch.setattr(module.MCPServerService, "get_servers", lambda *_args, **_kwargs: [{"id": "a"}, {"id": "b"}])
    res = _run(module.list_mcp())
    assert res["code"] == 0
    assert res["data"]["total"] == 2
    assert res["data"]["mcp_servers"] == [{"id": "b"}]
    # A service-layer exception must surface via server_error_response (code 100).
    monkeypatch.setattr(module, "request", SimpleNamespace(args={}))
    _set_request_json(monkeypatch, module, {"mcp_ids": []})

    def _raise_list(*_args, **_kwargs):
        raise RuntimeError("list explode")

    monkeypatch.setattr(module.MCPServerService, "get_servers", _raise_list)
    res = _run(module.list_mcp())
    assert res["code"] == 100
    assert "list explode" in res["message"]
@pytest.mark.p2
def test_detail_not_found_success_and_exception(monkeypatch):
    """detail: NOT_FOUND when missing, payload when found, code 100 on error."""
    module = _load_mcp_server_app(monkeypatch)
    monkeypatch.setattr(module, "request", SimpleNamespace(args={"mcp_id": "mcp-1"}))
    # No matching record -> NOT_FOUND.
    monkeypatch.setattr(module.MCPServerService, "get_or_none", lambda **_kwargs: None)
    res = module.detail()
    assert res["code"] == module.RetCode.NOT_FOUND
    # A matching record owned by the current tenant is returned as a dict.
    monkeypatch.setattr(
        module.MCPServerService,
        "get_or_none",
        lambda **_kwargs: _DummyMCPServer(id="mcp-1", name="srv", url="http://a", server_type="sse", tenant_id="tenant_1"),
    )
    res = module.detail()
    assert res["code"] == 0
    assert res["data"]["id"] == "mcp-1"

    # A service exception becomes server_error_response (code 100).
    def _raise_detail(**_kwargs):
        raise RuntimeError("detail explode")

    monkeypatch.setattr(module.MCPServerService, "get_or_none", _raise_detail)
    res = module.detail()
    assert res["code"] == 100
    assert "detail explode" in res["message"]
@pytest.mark.p2
def test_create_validation_guards(monkeypatch):
    """create: each input-validation branch returns its specific error message."""
    module = _load_mcp_server_app(monkeypatch)
    monkeypatch.setattr(module.MCPServerService, "get_by_name_and_tenant", lambda **_kwargs: (False, None))
    # Unknown server_type is rejected.
    _set_request_json(monkeypatch, module, {"name": "srv", "url": "http://a", "server_type": "invalid"})
    res = _run(module.create.__wrapped__())
    assert "Unsupported MCP server type" in res["message"]
    # Empty name fails the name check.
    _set_request_json(monkeypatch, module, {"name": "", "url": "http://a", "server_type": "sse"})
    res = _run(module.create.__wrapped__())
    assert "Invalid MCP name" in res["message"]
    # A name the tenant already owns is a duplicate.
    monkeypatch.setattr(module.MCPServerService, "get_by_name_and_tenant", lambda **_kwargs: (True, object()))
    _set_request_json(monkeypatch, module, {"name": "srv", "url": "http://a", "server_type": "sse"})
    res = _run(module.create.__wrapped__())
    assert "Duplicated MCP server name" in res["message"]
    # Blank URL fails the URL check.
    monkeypatch.setattr(module.MCPServerService, "get_by_name_and_tenant", lambda **_kwargs: (False, None))
    _set_request_json(monkeypatch, module, {"name": "srv", "url": "", "server_type": "sse"})
    res = _run(module.create.__wrapped__())
    assert "Invalid url" in res["message"]
@pytest.mark.p2
def test_create_service_paths(monkeypatch):
    """create: tenant miss, tool-discovery error, insert failure, success, and raise."""
    module = _load_mcp_server_app(monkeypatch)
    base_payload = {
        "name": "srv",
        "url": "http://server",
        "server_type": "sse",
        "headers": '{"Authorization": "x"}',
        "variables": '{"tools": {"old": 1}, "token": "abc"}',
        "timeout": "2.5",
    }
    monkeypatch.setattr(module, "get_uuid", lambda: "uuid-create")
    monkeypatch.setattr(module.MCPServerService, "get_by_name_and_tenant", lambda **_kwargs: (False, None))
    # Tenant lookup failure short-circuits creation.
    _set_request_json(monkeypatch, module, dict(base_payload))
    monkeypatch.setattr(module.TenantService, "get_by_id", lambda *_args, **_kwargs: (False, None))
    res = _run(module.create.__wrapped__())
    assert "Tenant not found" in res["message"]
    # Tool discovery returning an error string propagates it as the code.
    _set_request_json(monkeypatch, module, dict(base_payload))
    monkeypatch.setattr(module.TenantService, "get_by_id", lambda *_args, **_kwargs: (True, object()))

    async def _thread_pool_tools_error(_func, _servers, _timeout):
        return None, "tools error"

    monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_tools_error)
    res = _run(module.create.__wrapped__())
    assert res["code"] == "tools error"
    assert "Sorry! Data missing!" in res["message"]
    # Insert failure maps to a data-error result.
    _set_request_json(monkeypatch, module, dict(base_payload))

    async def _thread_pool_ok(_func, servers, _timeout):
        return {servers[0].name: [{"name": "tool_a"}, {"invalid": True}]}, None

    monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_ok)
    monkeypatch.setattr(module.MCPServerService, "insert", lambda **_kwargs: False)
    res = _run(module.create.__wrapped__())
    assert res["code"] == "Failed to create MCP server."
    assert "Sorry! Data missing!" in res["message"]
    # Happy path: stored tools keep only discovered entries that carry a "name".
    _set_request_json(monkeypatch, module, dict(base_payload))
    monkeypatch.setattr(module.MCPServerService, "insert", lambda **_kwargs: True)
    res = _run(module.create.__wrapped__())
    assert res["code"] == 0
    assert res["data"]["id"] == "uuid-create"
    assert res["data"]["tenant_id"] == "tenant_1"
    assert res["data"]["variables"]["tools"] == {"tool_a": {"name": "tool_a"}}
    # An unexpected exception is wrapped as code 100.
    _set_request_json(monkeypatch, module, dict(base_payload))

    async def _thread_pool_raises(_func, _servers, _timeout):
        raise RuntimeError("create explode")

    monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_raises)
    res = _run(module.create.__wrapped__())
    assert res["code"] == 100
    assert "create explode" in res["message"]
@pytest.mark.p2
def test_update_validation_guards(monkeypatch):
    """update: missing/foreign server, bad type, bad name, and bad url are rejected."""
    module = _load_mcp_server_app(monkeypatch)
    existing = _DummyMCPServer(id="mcp-1", name="srv", url="http://server", server_type="sse", tenant_id="tenant_1", variables={}, headers={})
    # Unknown id.
    _set_request_json(monkeypatch, module, {"mcp_id": "mcp-1"})
    monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (False, None))
    res = _run(module.update.__wrapped__())
    assert "Cannot find MCP server" in res["message"]
    # A server owned by another tenant is treated as not found.
    _set_request_json(monkeypatch, module, {"mcp_id": "mcp-1"})
    monkeypatch.setattr(
        module.MCPServerService,
        "get_by_id",
        lambda _mcp_id: (True, _DummyMCPServer(id="mcp-1", name="srv", url="http://server", server_type="sse", tenant_id="other", variables={}, headers={})),
    )
    res = _run(module.update.__wrapped__())
    assert "Cannot find MCP server" in res["message"]
    # Remaining guards run against the tenant-owned record.
    _set_request_json(monkeypatch, module, {"mcp_id": "mcp-1", "server_type": "invalid"})
    monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (True, existing))
    res = _run(module.update.__wrapped__())
    assert "Unsupported MCP server type" in res["message"]
    # Over-long (256-char) name is invalid.
    _set_request_json(monkeypatch, module, {"mcp_id": "mcp-1", "name": "a" * 256})
    res = _run(module.update.__wrapped__())
    assert "Invalid MCP name" in res["message"]
    _set_request_json(monkeypatch, module, {"mcp_id": "mcp-1", "url": ""})
    res = _run(module.update.__wrapped__())
    assert "Invalid url" in res["message"]
@pytest.mark.p2
def test_update_service_paths(monkeypatch):
    """update: tool error, filter_update failure, re-fetch failure, success, raise."""
    module = _load_mcp_server_app(monkeypatch)
    existing = _DummyMCPServer(
        id="mcp-1",
        name="srv",
        url="http://server",
        server_type="sse",
        tenant_id="tenant_1",
        variables={"tools": {"old": {"enabled": True}}, "token": "abc"},
        headers={"Authorization": "old"},
    )
    updated = _DummyMCPServer(
        id="mcp-1",
        name="srv-new",
        url="http://server-new",
        server_type="sse",
        tenant_id="tenant_1",
        variables={"tools": {"tool_a": {"name": "tool_a"}}},
        headers={"Authorization": "new"},
    )
    base_payload = {
        "mcp_id": "mcp-1",
        "name": "srv-new",
        "url": "http://server-new",
        "server_type": "sse",
        "headers": '{"Authorization": "new"}',
        "variables": '{"tools": {"ignore": 1}, "token": "new"}',
        "timeout": "3.0",
    }
    # Tool-discovery error string becomes the result code.
    _set_request_json(monkeypatch, module, dict(base_payload))
    monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (True, existing))

    async def _thread_pool_tools_error(_func, _servers, _timeout):
        return None, "update tools error"

    monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_tools_error)
    res = _run(module.update.__wrapped__())
    assert res["code"] == "update tools error"
    assert "Sorry! Data missing!" in res["message"]
    # filter_update returning False reports an update failure.
    _set_request_json(monkeypatch, module, dict(base_payload))

    async def _thread_pool_ok(_func, servers, _timeout):
        return {servers[0].name: [{"name": "tool_a"}, {"bad": True}]}, None

    monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_ok)
    monkeypatch.setattr(module.MCPServerService, "filter_update", lambda *_args, **_kwargs: False)
    res = _run(module.update.__wrapped__())
    assert "Failed to updated MCP server" in res["message"]
    # The second get_by_id call (post-update re-fetch) failing is reported distinctly.
    _set_request_json(monkeypatch, module, dict(base_payload))
    monkeypatch.setattr(module.MCPServerService, "filter_update", lambda *_args, **_kwargs: True)

    def _get_by_id_fetch_fail(_mcp_id):
        if _get_by_id_fetch_fail.calls == 0:
            _get_by_id_fetch_fail.calls += 1
            return True, existing
        return False, None

    _get_by_id_fetch_fail.calls = 0
    monkeypatch.setattr(module.MCPServerService, "get_by_id", _get_by_id_fetch_fail)
    res = _run(module.update.__wrapped__())
    assert "Failed to fetch updated MCP server" in res["message"]
    # Happy path: first call yields the existing row, second yields the updated one.
    _set_request_json(monkeypatch, module, dict(base_payload))

    def _get_by_id_success(_mcp_id):
        if _get_by_id_success.calls == 0:
            _get_by_id_success.calls += 1
            return True, existing
        return True, updated

    _get_by_id_success.calls = 0
    monkeypatch.setattr(module.MCPServerService, "get_by_id", _get_by_id_success)
    res = _run(module.update.__wrapped__())
    assert res["code"] == 0
    assert res["data"]["id"] == "mcp-1"
    # An unexpected exception is wrapped as code 100.
    _set_request_json(monkeypatch, module, dict(base_payload))
    monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (True, existing))

    async def _thread_pool_raises(_func, _servers, _timeout):
        raise RuntimeError("update explode")

    monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_raises)
    res = _run(module.update.__wrapped__())
    assert res["code"] == 100
    assert "update explode" in res["message"]
@pytest.mark.p2
def test_rm_failure_success_and_exception(monkeypatch):
    """rm: delete failure message, successful delete, and exception -> code 100."""
    module = _load_mcp_server_app(monkeypatch)
    _set_request_json(monkeypatch, module, {"mcp_ids": ["a", "b"]})
    monkeypatch.setattr(module.MCPServerService, "delete_by_ids", lambda _ids: False)
    res = _run(module.rm.__wrapped__())
    assert "Failed to delete MCP servers" in res["message"]
    _set_request_json(monkeypatch, module, {"mcp_ids": ["a", "b"]})
    monkeypatch.setattr(module.MCPServerService, "delete_by_ids", lambda _ids: True)
    res = _run(module.rm.__wrapped__())
    assert res["code"] == 0
    assert res["data"] is True
    # A service exception surfaces via server_error_response.
    _set_request_json(monkeypatch, module, {"mcp_ids": ["a", "b"]})

    def _raise_rm(_ids):
        raise RuntimeError("rm explode")

    monkeypatch.setattr(module.MCPServerService, "delete_by_ids", _raise_rm)
    res = _run(module.rm.__wrapped__())
    assert res["code"] == 100
    assert "rm explode" in res["message"]
@pytest.mark.p2
def test_import_multiple_missing_servers_and_exception(monkeypatch):
    """import_multiple: empty mcpServers is rejected; unexpected error -> code 100."""
    module = _load_mcp_server_app(monkeypatch)
    _set_request_json(monkeypatch, module, {"mcpServers": {}})
    res = _run(module.import_multiple.__wrapped__())
    assert "No MCP servers provided" in res["message"]
    _set_request_json(monkeypatch, module, {"mcpServers": {"srv": {"type": "sse", "url": "http://x"}}, "timeout": "1"})

    def _raise_import(**_kwargs):
        raise RuntimeError("import explode")

    monkeypatch.setattr(module.MCPServerService, "get_by_name_and_tenant", _raise_import)
    res = _run(module.import_multiple.__wrapped__())
    assert res["code"] == 100
    assert "import explode" in res["message"]
@pytest.mark.p2
def test_import_multiple_mixed_results(monkeypatch):
    """import_multiple: per-server results for every failure/rename/success branch."""
    module = _load_mcp_server_app(monkeypatch)
    payload = {
        "mcpServers": {
            "missing_fields": {"type": "sse"},
            "": {"type": "sse", "url": "http://empty"},
            "dup": {"type": "sse", "url": "http://dup", "authorization_token": "dup-token"},
            "tool_err": {"type": "sse", "url": "http://err"},
            "insert_fail": {"type": "sse", "url": "http://fail"},
        },
        "timeout": "3",
    }
    _set_request_json(monkeypatch, module, payload)
    monkeypatch.setattr(module, "get_uuid", lambda: "uuid-import")

    # Only the first lookup of "dup" reports a clash, so the import renames it
    # to "dup_0" and the rename itself does not clash again.
    def _get_by_name_and_tenant(name, tenant_id):
        if name == "dup" and not _get_by_name_and_tenant.first_dup_seen:
            _get_by_name_and_tenant.first_dup_seen = True
            return True, object()
        return False, None

    _get_by_name_and_tenant.first_dup_seen = False
    monkeypatch.setattr(module.MCPServerService, "get_by_name_and_tenant", _get_by_name_and_tenant)

    # Tool discovery fails only for "tool_err"; others yield one valid tool.
    async def _thread_pool_exec(func, servers, _timeout):
        mcp_server = servers[0]
        if mcp_server.name == "tool_err":
            return None, "tool call failed"
        return {mcp_server.name: [{"name": "tool_a"}, {"invalid": True}]}, None

    monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_exec)

    def _insert(**kwargs):
        # Only "insert_fail" fails to persist.
        return kwargs["name"] != "insert_fail"

    monkeypatch.setattr(module.MCPServerService, "insert", _insert)
    res = _run(module.import_multiple.__wrapped__())
    assert res["code"] == 0
    results = {item["server"]: item for item in res["data"]["results"]}
    assert results["missing_fields"]["success"] is False
    assert "Missing required fields" in results["missing_fields"]["message"]
    assert results[""]["success"] is False
    assert "Invalid MCP name" in results[""]["message"]
    assert results["tool_err"]["success"] is False
    assert "tool call failed" in results["tool_err"]["message"]
    assert results["insert_fail"]["success"] is False
    assert "Failed to create MCP server" in results["insert_fail"]["message"]
    assert results["dup"]["success"] is True
    assert results["dup"]["new_name"] == "dup_0"
    assert "Renamed from 'dup' to 'dup_0' avoid duplication" == results["dup"]["message"]
@pytest.mark.p2
def test_export_multiple_missing_ids_success_and_exception(monkeypatch):
    """export_multiple: empty ids rejected; only tenant-owned servers exported."""
    module = _load_mcp_server_app(monkeypatch)
    _set_request_json(monkeypatch, module, {"mcp_ids": []})
    res = _run(module.export_multiple.__wrapped__())
    assert "No MCP server IDs provided" in res["message"]
    _set_request_json(monkeypatch, module, {"mcp_ids": ["id1", "id2", "id3"]})

    # id1 belongs to the current tenant, id2 to another tenant, id3 is missing;
    # only id1 should make it into the export.
    def _get_by_id(mcp_id):
        if mcp_id == "id1":
            return True, _DummyMCPServer(
                id="id1",
                name="srv-one",
                url="http://one",
                server_type="sse",
                tenant_id="tenant_1",
                variables={"authorization_token": "tok", "tools": {"tool_a": {"enabled": True}}},
            )
        if mcp_id == "id2":
            return True, _DummyMCPServer(
                id="id2",
                name="srv-two",
                url="http://two",
                server_type="sse",
                tenant_id="other",
                variables={},
            )
        return False, None

    monkeypatch.setattr(module.MCPServerService, "get_by_id", _get_by_id)
    res = _run(module.export_multiple.__wrapped__())
    assert res["code"] == 0
    assert list(res["data"]["mcpServers"].keys()) == ["srv-one"]
    # An unexpected exception is wrapped as code 100.
    _set_request_json(monkeypatch, module, {"mcp_ids": ["id1"]})

    def _raise_export(_mcp_id):
        raise RuntimeError("export explode")

    monkeypatch.setattr(module.MCPServerService, "get_by_id", _raise_export)
    res = _run(module.export_multiple.__wrapped__())
    assert res["code"] == 100
    assert "export explode" in res["message"]
@pytest.mark.p2
def test_list_tools_missing_ids_success_inner_error_outer_error_and_finally_cleanup(monkeypatch):
    """list_tools: id guard, merge of cached enabled flags, both error paths,
    and the finally-block session cleanup on every outcome."""
    module = _load_mcp_server_app(monkeypatch)
    _set_request_json(monkeypatch, module, {"mcp_ids": []})
    res = _run(module.list_tools.__wrapped__())
    assert "No MCP server IDs provided" in res["message"]
    server = _DummyMCPServer(
        id="id1",
        name="srv-tools",
        url="http://tools",
        server_type="sse",
        tenant_id="tenant_1",
        variables={"tools": {"tool_a": {"enabled": False}}},
    )
    _set_request_json(monkeypatch, module, {"mcp_ids": ["id1"], "timeout": "2.0"})
    monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (True, server))
    close_calls = []

    # Successful path: record the sessions handed to the cleanup helper.
    async def _thread_pool_exec_success(func, *args):
        if func is module.close_multiple_mcp_toolcall_sessions:
            close_calls.append(args[0])
            return None
        return func(*args)

    monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_exec_success)
    res = _run(module.list_tools.__wrapped__())
    assert res["code"] == 0
    # Cached enabled=False is preserved; newly discovered tools default to True.
    assert res["data"]["id1"][0]["name"] == "tool_a"
    assert res["data"]["id1"][0]["enabled"] is False
    assert res["data"]["id1"][1]["enabled"] is True
    assert close_calls and len(close_calls[-1]) == 1
    _set_request_json(monkeypatch, module, {"mcp_ids": ["id1"], "timeout": "2.0"})
    close_calls_inner = []

    # Inner failure (tool listing raises) -> data error; sessions still closed.
    async def _thread_pool_exec_inner_error(func, *args):
        if func is module.close_multiple_mcp_toolcall_sessions:
            close_calls_inner.append(args[0])
            return None
        raise RuntimeError("inner tools explode")

    monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_exec_inner_error)
    res = _run(module.list_tools.__wrapped__())
    assert res["code"] == 102
    assert "MCP list tools error" in res["message"]
    assert close_calls_inner and len(close_calls_inner[-1]) == 1
    _set_request_json(monkeypatch, module, {"mcp_ids": ["id1"], "timeout": "2.0"})
    close_calls_outer = []

    # Outer failure (lookup raises before any session exists) -> code 100.
    def _raise_get_by_id(_mcp_id):
        raise RuntimeError("outer explode")

    monkeypatch.setattr(module.MCPServerService, "get_by_id", _raise_get_by_id)

    async def _thread_pool_exec_outer(func, *args):
        if func is module.close_multiple_mcp_toolcall_sessions:
            close_calls_outer.append(args[0])
            return None
        return func(*args)

    monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_exec_outer)
    res = _run(module.list_tools.__wrapped__())
    assert res["code"] == 100
    assert "outer explode" in res["message"]
    assert close_calls_outer
@pytest.mark.p2
def test_test_tool_missing_mcp_id(monkeypatch):
    """test_tool: an empty mcp_id is rejected before any session is created."""
    module = _load_mcp_server_app(monkeypatch)
    _set_request_json(monkeypatch, module, {"mcp_id": "", "tool_name": "tool_a", "arguments": {"x": 1}})
    res = _run(module.test_tool.__wrapped__())
    assert "No MCP server ID provided" in res["message"]

View File

@ -20,8 +20,6 @@ import pytest
from test_web_api.common import create_memory
from configs import INVALID_API_TOKEN
from libs.auth import RAGFlowWebApiAuth
from hypothesis import example, given, settings
from utils.hypothesis_utils import valid_names
class TestAuthorization:
@ -42,9 +40,7 @@ class TestAuthorization:
class TestMemoryCreate:
@pytest.mark.p1
@given(name=valid_names())
@example("d" * 128)
@settings(max_examples=20)
@pytest.mark.parametrize("name", ["test_memory_name", "d" * 128])
def test_name(self, WebApiAuth, name):
payload = {
"name": name,
@ -79,7 +75,7 @@ class TestMemoryCreate:
assert res["message"] == expected_message, res
@pytest.mark.p2
@given(name=valid_names())
@pytest.mark.parametrize("name", ["invalid_type_name", "memory_alpha"])
def test_type_invalid(self, WebApiAuth, name):
payload = {
"name": name,

View File

@ -13,14 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
import pytest
from test_web_api.common import update_memory
from configs import INVALID_API_TOKEN
from libs.auth import RAGFlowWebApiAuth
from hypothesis import HealthCheck, example, given, settings
from utils import encode_avatar
from utils.file_utils import create_image_file
from utils.hypothesis_utils import valid_names
class TestAuthorization:
@ -42,15 +42,14 @@ class TestAuthorization:
class TestMemoryUpdate:
@pytest.mark.p1
@given(name=valid_names())
@example("f" * 128)
@settings(max_examples=20, suppress_health_check=[HealthCheck.function_scoped_fixture])
@pytest.mark.parametrize("name", ["updated_memory", "f" * 128])
def test_name(self, WebApiAuth, add_memory_func, name):
memory_ids = add_memory_func
payload = {"name": name}
res = update_memory(WebApiAuth, memory_ids[0], payload)
assert res["code"] == 0, res
assert res["data"]["name"] == name, res
pattern = rf"^{re.escape(name)}(?:\(\d+\))?$"
assert re.match(pattern, res["data"]["name"]), res
@pytest.mark.p2
@pytest.mark.parametrize(

View File

@ -71,6 +71,20 @@ Are you asking about the fruit itself, or its use in a specific context?
assert message["agent_id"] == agent_id, message
assert message["session_id"] == session_id, message
@pytest.mark.p2
def test_add_message_invalid_memory_id(self, WebApiAuth):
message_payload = {
"memory_id": ["missing_memory_id"],
"agent_id": uuid.uuid4().hex,
"session_id": uuid.uuid4().hex,
"user_id": "",
"user_input": "what is pineapple?",
"agent_response": "pineapple response",
}
res = add_message(WebApiAuth, message_payload)
assert res["code"] == 500, res
assert "Some messages failed to add" in res["message"], res
@pytest.mark.usefixtures("add_empty_multiple_type_memory")
class TestAddMultipleTypeMessage:

View File

@ -15,8 +15,9 @@
#
import random
import pytest
import requests
from test_web_api.common import forget_message, list_memory_message, get_message_content
from configs import INVALID_API_TOKEN
from configs import HOST_ADDRESS, INVALID_API_TOKEN, VERSION
from libs.auth import RAGFlowWebApiAuth
@ -52,3 +53,17 @@ class TestForgetMessage:
forgot_message_res = get_message_content(WebApiAuth, memory_id, message["message_id"])
assert forgot_message_res["code"] == 0, forgot_message_res
assert forgot_message_res["data"]["forget_at"] not in ["-", ""], forgot_message_res
@pytest.mark.p2
def test_forget_message_invalid_memory_id(self, WebApiAuth):
res = forget_message(WebApiAuth, "missing_memory_id", 1)
assert res["code"] == 404, res
assert "not found" in res["message"].lower(), res
@pytest.mark.p2
def test_forget_message_invalid_message_id(self, WebApiAuth):
memory_id = self.memory_id
url = f"{HOST_ADDRESS}/api/{VERSION}/messages/{memory_id}:invalid_message_id"
res = requests.delete(url=url, headers={"Content-Type": "application/json"}, auth=WebApiAuth).json()
assert res["code"] == 500, res
assert "Internal server error" in res["message"], res

View File

@ -49,3 +49,16 @@ class TestGetMessageContent:
for field in ["content", "content_embed"]:
assert field in content_res["data"]
assert content_res["data"][field] is not None, content_res
@pytest.mark.p2
def test_get_message_content_invalid_memory_id(self, WebApiAuth):
res = get_message_content(WebApiAuth, "missing_memory_id", 1)
assert res["code"] == 404, res
assert "not found" in res["message"].lower(), res
@pytest.mark.p2
def test_get_message_content_invalid_message_id(self, WebApiAuth):
memory_id = self.memory_id
res = get_message_content(WebApiAuth, memory_id, 999999999)
assert res["code"] == 404, res
assert "not found" in res["message"].lower(), res

View File

@ -66,3 +66,15 @@ class TestGetRecentMessage:
for message in res["data"]:
assert message["session_id"] == session_id, message
@pytest.mark.p2
def test_get_recent_messages_missing_memory_id(self, WebApiAuth):
res = get_recent_message(WebApiAuth, params={})
assert res["code"] == 101, res
assert "memory_ids is required" in res["message"], res
@pytest.mark.p2
def test_get_recent_messages_csv_memory_ids(self, WebApiAuth):
memory_id = self.memory_id
res = get_recent_message(WebApiAuth, params={"memory_id": f"{memory_id},{memory_id}"})
assert res["code"] == 0, res
assert isinstance(res["data"], list), res

View File

@ -0,0 +1,151 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import importlib.util
import inspect
import sys
from copy import deepcopy
from pathlib import Path
from types import ModuleType, SimpleNamespace
import pytest
class _DummyManager:
def route(self, *_args, **_kwargs):
def decorator(func):
return func
return decorator
class _AwaitableValue:
def __init__(self, value):
self._value = value
def __await__(self):
async def _co():
return self._value
return _co().__await__()
class _DummyArgs(dict):
def getlist(self, key):
value = self.get(key)
if value is None:
return []
if isinstance(value, list):
return value
return [value]
class _DummyMemoryApiService:
async def add_message(self, *_args, **_kwargs):
return True, "ok"
async def get_messages(self, *_args, **_kwargs):
return []
def _run(coro):
    # Execute an async route handler synchronously on a fresh event loop.
    return asyncio.run(coro)
def _load_memory_routes_module(monkeypatch):
    """Execute ``api/apps/restful_apis/memory_api.py`` against stub modules.

    Stubs for ``common``, ``api.apps`` and ``api.apps.services`` are installed
    in ``sys.modules`` *before* the target file is loaded so its imports
    resolve without the real application stack. Returns the executed module.
    """
    repo_root = Path(__file__).resolve().parents[4]
    common_pkg = ModuleType("common")
    common_pkg.__path__ = [str(repo_root / "common")]
    monkeypatch.setitem(sys.modules, "common", common_pkg)
    apps_mod = ModuleType("api.apps")
    apps_mod.__path__ = [str(repo_root / "api" / "apps")]
    apps_mod.current_user = SimpleNamespace(id="user-1")
    # Neutralize auth so handlers can be invoked directly.
    apps_mod.login_required = lambda func: func
    monkeypatch.setitem(sys.modules, "api.apps", apps_mod)
    services_mod = ModuleType("api.apps.services")
    services_mod.memory_api_service = _DummyMemoryApiService()
    monkeypatch.setitem(sys.modules, "api.apps.services", services_mod)
    module_name = "test_message_routes_unit_module"
    module_path = repo_root / "api" / "apps" / "restful_apis" / "memory_api.py"
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    module = importlib.util.module_from_spec(spec)
    # The blueprint manager must exist before exec since routes register on it.
    module.manager = _DummyManager()
    monkeypatch.setitem(sys.modules, module_name, module)
    spec.loader.exec_module(module)
    return module
def _set_request_json(monkeypatch, module, payload):
    # Replace get_request_json with an awaitable yielding a fresh deep copy of
    # the payload, so the handler under test cannot mutate the test's dict.
    monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue(deepcopy(payload)))
@pytest.mark.p2
def test_add_message_partial_failure_branch(monkeypatch):
    """add_message: a failing service call yields SERVER_ERROR with details."""
    module = _load_memory_routes_module(monkeypatch)
    _set_request_json(
        monkeypatch,
        module,
        {
            "memory_id": ["memory-1"],
            "agent_id": "agent-1",
            "session_id": "session-1",
            "user_input": "hello",
            "agent_response": "world",
        },
    )

    # Service reports failure for the (only) memory id.
    async def _add_message(_memory_ids, _message_dict):
        return False, "cannot enqueue"

    monkeypatch.setattr(module.memory_api_service, "add_message", _add_message)
    # inspect.unwrap strips the login/validation decorators around the handler.
    res = _run(inspect.unwrap(module.add_message)())
    assert res["code"] == module.RetCode.SERVER_ERROR, res
    assert "Some messages failed to add" in res["message"], res
@pytest.mark.p2
def test_get_messages_csv_and_missing_memory_ids(monkeypatch):
    """get_messages: missing memory_ids -> ARGUMENT_ERROR; CSV ids are split."""
    module = _load_memory_routes_module(monkeypatch)
    monkeypatch.setattr(module, "request", SimpleNamespace(args=_DummyArgs({})))
    res = _run(inspect.unwrap(module.get_messages)())
    assert res["code"] == module.RetCode.ARGUMENT_ERROR, res
    assert "memory_ids is required." in res["message"], res
    monkeypatch.setattr(
        module,
        "request",
        SimpleNamespace(args=_DummyArgs({"memory_id": "m1,m2", "agent_id": "a1", "session_id": "s1", "limit": "5"})),
    )

    # Assert inside the stub that the handler parsed every query parameter.
    async def _get_messages(memory_ids, agent_id, session_id, limit):
        assert memory_ids == ["m1", "m2"]
        assert agent_id == "a1"
        assert session_id == "s1"
        assert limit == 5
        return [{"message_id": 1}]

    monkeypatch.setattr(module.memory_api_service, "get_messages", _get_messages)
    res = _run(inspect.unwrap(module.get_messages)())
    assert res["code"] == module.RetCode.SUCCESS, res
    assert isinstance(res["data"], list), res

View File

@ -80,3 +80,23 @@ class TestSearchMessage:
assert res["code"] == 0, res
assert len(res["data"]) > 0
assert len(res["data"]) <= params["top_n"]
@pytest.mark.p2
def test_query_missing_query(self, WebApiAuth):
memory_id = self.memory_id
res = search_message(WebApiAuth, {"memory_id": memory_id})
assert res["code"] in [100, 500], res
@pytest.mark.p2
def test_query_missing_memory_id(self, WebApiAuth):
res = search_message(WebApiAuth, {"query": "what is coriander"})
assert res["code"] == 0, res
assert isinstance(res["data"], list), res
@pytest.mark.p2
def test_query_with_csv_memory_ids(self, WebApiAuth):
memory_id = self.memory_id
query = "Coriander is a versatile herb."
res = search_message(WebApiAuth, {"memory_id": f"{memory_id},{memory_id}", "query": query})
assert res["code"] == 0, res
assert isinstance(res["data"], list), res

View File

@ -16,9 +16,11 @@
import random
import pytest
import requests
from test_web_api.common import update_message_status, list_memory_message, get_message_content
from configs import INVALID_API_TOKEN
from libs.auth import RAGFlowWebApiAuth
from configs import HOST_ADDRESS, VERSION
class TestAuthorization:
@ -73,3 +75,34 @@ class TestUpdateMessageStatus:
res = get_message_content(WebApiAuth, memory_id, message["message_id"])
assert res["code"] == 0, res
assert res["data"]["status"], res
@pytest.mark.p2
def test_update_invalid_status_type(self, WebApiAuth):
memory_id = self.memory_id
list_res = list_memory_message(WebApiAuth, memory_id)
assert list_res["code"] == 0, list_res
message_id = list_res["data"]["messages"]["message_list"][0]["message_id"]
url = f"{HOST_ADDRESS}/api/{VERSION}/messages/{memory_id}:{message_id}"
res = requests.put(url=url, headers={"Content-Type": "application/json"}, auth=WebApiAuth, json={"status": "false"}).json()
assert res["code"] == 101, res
assert "Status must be a boolean." in res["message"], res
@pytest.mark.p2
def test_update_invalid_memory_id(self, WebApiAuth):
res = update_message_status(WebApiAuth, "missing_memory_id", 1, False)
assert res["code"] == 404, res
assert "not found" in res["message"].lower(), res
@pytest.mark.p2
def test_update_invalid_message_id(self, WebApiAuth):
memory_id = self.memory_id
url = f"{HOST_ADDRESS}/api/{VERSION}/messages/{memory_id}:invalid_message_id"
res = requests.put(
url=url,
headers={"Content-Type": "application/json"},
auth=WebApiAuth,
json={"status": True},
).json()
assert res["code"] == 500, res
assert "Internal server error" in res["message"], res

View File

@ -13,6 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import importlib.util
import sys
from pathlib import Path
from types import ModuleType
import pytest
from common import plugin_llm_tools
from configs import INVALID_API_TOKEN
@ -40,3 +45,54 @@ class TestPluginTools:
res = plugin_llm_tools(WebApiAuth)
assert res["code"] == 0, res
assert isinstance(res["data"], list), res
class _DummyManager:
def route(self, *_args, **_kwargs):
def decorator(func):
return func
return decorator
def _load_plugin_app(monkeypatch):
    """Load ``api/apps/plugin_app.py`` in isolation with stubbed dependencies.

    The real module imports the ``common`` package, the Quart app package
    (``api.apps``) and the plugin manager; each of those is replaced with a
    minimal in-memory stub so the route handlers can be called as plain
    functions without a running server.

    Returns the freshly executed module object.
    """
    repo_root = Path(__file__).resolve().parents[4]

    # Expose the repository's real ``common`` package under a fresh module
    # object so importing it pulls no unrelated side effects.
    common_pkg = ModuleType("common")
    common_pkg.__path__ = [str(repo_root / "common")]
    monkeypatch.setitem(sys.modules, "common", common_pkg)

    # ``login_required`` becomes a pass-through so handlers need no session.
    stub_apps = ModuleType("api.apps")
    stub_apps.login_required = lambda func: func
    monkeypatch.setitem(sys.modules, "api.apps", stub_apps)

    # Plugin manager stub that reports no registered LLM tools by default;
    # individual tests monkeypatch ``get_llm_tools`` as needed.
    stub_plugin = ModuleType("agent.plugin")

    class _StubGlobalPluginManager:
        @staticmethod
        def get_llm_tools():
            return []

    stub_plugin.GlobalPluginManager = _StubGlobalPluginManager
    monkeypatch.setitem(sys.modules, "agent.plugin", stub_plugin)

    # Execute plugin_app.py under a throwaway module name. Reuse repo_root
    # instead of re-resolving Path(__file__).resolve().parents[4] a second time.
    module_path = repo_root / "api" / "apps" / "plugin_app.py"
    spec = importlib.util.spec_from_file_location("test_plugin_app_unit", module_path)
    module = importlib.util.module_from_spec(spec)
    module.manager = _DummyManager()
    spec.loader.exec_module(module)
    return module
@pytest.mark.p2
def test_llm_tools_metadata_shape_unit(monkeypatch):
    """llm_tools() returns each registered tool's metadata dict under ``data``."""
    module = _load_plugin_app(monkeypatch)

    class _FakeTool:
        def get_metadata(self):
            return {"name": "dummy", "description": "test"}

    monkeypatch.setattr(module.GlobalPluginManager, "get_llm_tools", staticmethod(lambda: [_FakeTool()]))
    res = module.llm_tools()
    assert res["code"] == 0
    assert isinstance(res["data"], list)
    metadata = res["data"][0]
    assert metadata["name"] == "dummy"
    assert metadata["description"] == "test"

View File

@ -0,0 +1,242 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import importlib.util
import logging
import sys
from pathlib import Path
from types import ModuleType, SimpleNamespace
import pytest
from werkzeug.exceptions import Unauthorized as WerkzeugUnauthorized
class _DummyAPIToken:
@staticmethod
def query(**_kwargs):
return []
class _DummyUserService:
@staticmethod
def query(**_kwargs):
return []
def _run(coro):
    # Drive an async test body to completion on a fresh event loop.
    return asyncio.run(coro)
def _load_apps_module(monkeypatch):
    # Load api/apps/__init__.py under a throwaway module name, with all of
    # its heavyweight dependencies (settings, DB models, services, CLI
    # commands, api_utils) replaced by in-memory stubs so the Quart app can
    # be constructed without a database or configuration files.
    # Returns (quart_app, module).
    repo_root = Path(__file__).resolve().parents[4]
    # Point the "common" package at the repository's real sources so
    # submodule imports resolve, but without its package __init__ side effects.
    common_pkg = ModuleType("common")
    common_pkg.__path__ = [str(repo_root / "common")]
    monkeypatch.setitem(sys.modules, "common", common_pkg)
    # Stub settings: fixed secret key, no-op init, empty DB config.
    settings_mod = ModuleType("common.settings")
    settings_mod.SECRET_KEY = "test-secret-key"
    settings_mod.init_settings = lambda: None
    settings_mod.decrypt_database_config = lambda name=None: {}
    monkeypatch.setitem(sys.modules, "common.settings", settings_mod)
    common_pkg.settings = settings_mod
    # Stub DB models: APIToken lookups return nothing; connection close is a no-op.
    db_models_mod = ModuleType("api.db.db_models")
    db_models_mod.APIToken = _DummyAPIToken
    db_models_mod.close_connection = lambda: None
    monkeypatch.setitem(sys.modules, "api.db.db_models", db_models_mod)
    # Stub services: UserService lookups return nothing.
    services_mod = ModuleType("api.db.services")
    services_mod.UserService = _DummyUserService
    monkeypatch.setitem(sys.modules, "api.db.services", services_mod)
    # Stub CLI command registration as a no-op.
    commands_mod = ModuleType("api.utils.commands")
    commands_mod.register_commands = lambda _app: None
    monkeypatch.setitem(sys.modules, "api.utils.commands", commands_mod)
    # Stub api_utils with minimal JSON-result helpers matching the real shapes.
    api_utils_mod = ModuleType("api.utils.api_utils")

    def _get_json_result(code=0, message="success", data=None):
        return {"code": code, "message": message, "data": data}

    def _server_error_response(error):
        return {"code": 100, "message": repr(error)}

    api_utils_mod.get_json_result = _get_json_result
    api_utils_mod.server_error_response = _server_error_response
    monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod)
    module_name = "test_apps_init_unit_module"
    module_path = repo_root / "api" / "apps" / "__init__.py"
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    module = importlib.util.module_from_spec(spec)
    monkeypatch.setitem(sys.modules, module_name, module)
    # Patch Path.glob to return nothing so the module's page-discovery loop
    # registers no blueprints during import.
    monkeypatch.setattr(Path, "glob", lambda self, _pattern: [])
    spec.loader.exec_module(module)
    return module.app, module
@pytest.mark.p2
def test_module_init_and_unauthorized_message_variants(monkeypatch):
    """_unauthorized_message uses a repr only when it mentions Unauthorized/401; otherwise the default."""
    _quart_app, apps_module = _load_apps_module(monkeypatch)
    # With Path.glob stubbed to empty, no client URL prefixes are collected.
    assert apps_module.client_urls_prefix == []

    class _ReprRaises:
        def __repr__(self):
            raise RuntimeError("repr explode")

    class _ReprIsDefault:
        def __repr__(self):
            return apps_module.UNAUTHORIZED_MESSAGE

    class _ReprUnauthorized401:
        def __repr__(self):
            return "Unauthorized 401 from upstream"

    class _ReprUnrelated:
        def __repr__(self):
            return "Forbidden 403"

    default_message = apps_module.UNAUTHORIZED_MESSAGE
    assert apps_module._unauthorized_message(None) == default_message
    assert apps_module._unauthorized_message(_ReprRaises()) == default_message
    assert apps_module._unauthorized_message(_ReprIsDefault()) == default_message
    assert apps_module._unauthorized_message(_ReprUnauthorized401()) == "Unauthorized 401 from upstream"
    assert apps_module._unauthorized_message(_ReprUnrelated()) == default_message
@pytest.mark.p2
def test_load_user_token_edge_cases(monkeypatch):
    """_load_user rejects empty tokens, too-short tokens, and users with no access_token."""
    quart_app, apps_module = _load_apps_module(monkeypatch)
    stored_user = SimpleNamespace(email="empty@example.com", access_token="")

    async def _scenarios():
        # Decoded token is empty -> no user.
        async with quart_app.test_request_context("/", headers={"Authorization": "token"}):
            monkeypatch.setattr(apps_module.Serializer, "loads", lambda _self, _auth: "")
            assert apps_module._load_user() is None
        # Decoded token is too short -> no user.
        async with quart_app.test_request_context("/", headers={"Authorization": "token"}):
            monkeypatch.setattr(apps_module.Serializer, "loads", lambda _self, _auth: "short-token")
            assert apps_module._load_user() is None
        # Token decodes fine, but the matched user record carries an empty
        # access_token -> still no user.
        async with quart_app.test_request_context("/", headers={"Authorization": "token"}):
            monkeypatch.setattr(apps_module.Serializer, "loads", lambda _self, _auth: "a" * 32)
            monkeypatch.setattr(apps_module.UserService, "query", lambda **_kwargs: [stored_user])
            assert apps_module._load_user() is None

    _run(_scenarios())
@pytest.mark.p2
def test_load_user_api_token_fallback_and_fallback_exception(monkeypatch, caplog):
    # Covers _load_user's API-token fallback path: when the session
    # serializer raises, the loader falls back to an APIToken lookup.
    quart_app, apps_module = _load_apps_module(monkeypatch)

    def _raise_decode(_self, _auth):
        raise RuntimeError("decode failed")

    # Force every serializer decode to fail so the fallback is always taken.
    monkeypatch.setattr(apps_module.Serializer, "loads", _raise_decode)
    fallback_user_empty_token = SimpleNamespace(email="fallback@example.com", access_token="")

    async def _case():
        # Fallback resolves a tenant, but the resolved user has an empty
        # access_token, so authentication still yields no user.
        monkeypatch.setattr(apps_module.APIToken, "query", lambda **_kwargs: [SimpleNamespace(tenant_id="tenant-1")])
        monkeypatch.setattr(apps_module.UserService, "query", lambda **_kwargs: [fallback_user_empty_token])
        async with quart_app.test_request_context("/", headers={"Authorization": "Bearer api-token"}):
            assert apps_module._load_user() is None

        def _raise_api_token(**_kwargs):
            raise RuntimeError("api token fallback failed")

        # If the fallback lookup itself raises, _load_user still returns None
        # and the failure is logged (asserted via caplog after _run).
        monkeypatch.setattr(apps_module.APIToken, "query", _raise_api_token)
        async with quart_app.test_request_context("/", headers={"Authorization": "Bearer api-token"}):
            with caplog.at_level(logging.WARNING):
                assert apps_module._load_user() is None

    _run(_case())
    assert "api token fallback failed" in caplog.text
@pytest.mark.p2
def test_login_required_timing_and_login_user_inactive(monkeypatch, caplog):
    """RAGFLOW_API_TIMING=1 makes login_required log timing; login_user rejects inactive users."""
    quart_app, apps_module = _load_apps_module(monkeypatch)
    monkeypatch.setenv("RAGFLOW_API_TIMING", "1")
    monkeypatch.setattr(apps_module, "current_user", SimpleNamespace(id="tenant-1"))

    @apps_module.login_required
    async def _timed_handler():
        return {"ok": True}

    async def _exercise():
        async with quart_app.test_request_context("/timed"):
            with caplog.at_level(logging.INFO):
                assert await _timed_handler() == {"ok": True}
        # A user flagged inactive cannot be logged in.
        assert apps_module.login_user(SimpleNamespace(id="user-1", is_active=False)) is False

    _run(_exercise())
    assert "api_timing login_required" in caplog.text
@pytest.mark.p2
def test_logout_user_not_found_and_unauthorized_handlers(monkeypatch):
    # Exercises logout_user() session cleanup, the not_found error handler,
    # and both unauthorized handlers (quart-auth and werkzeug variants).
    quart_app, apps_module = _load_apps_module(monkeypatch)

    async def _case():
        # logout_user must clear every login-related session key and mark the
        # remember-me cookie for clearing.
        async with quart_app.test_request_context("/logout", headers={"Cookie": "remember_token=abc"}):
            from quart import session

            session["_user_id"] = "user-1"
            session["_fresh"] = True
            session["_id"] = "session-id"
            session["_remember_seconds"] = 5
            assert apps_module.logout_user() is True
            assert "_user_id" not in session
            assert "_fresh" not in session
            assert "_id" not in session
            assert session.get("_remember") == "clear"
            assert "_remember_seconds" not in session

        # not_found wraps any error into a JSON payload carrying NOT_FOUND.
        async with quart_app.test_request_context("/missing/path"):
            not_found_resp, status = await apps_module.not_found(RuntimeError("missing"))
            assert status == apps_module.RetCode.NOT_FOUND
            payload = await not_found_resp.get_json()
            assert payload["code"] == apps_module.RetCode.NOT_FOUND
            assert payload["error"] == "Not Found"
            assert "Not Found:" in payload["message"]

        # With current_user forced to None, login_required raises
        # QuartAuthUnauthorized; both handlers translate to UNAUTHORIZED.
        async with quart_app.test_request_context("/protected"):

            @apps_module.login_required
            async def _protected():
                return {"ok": True}

            monkeypatch.setattr(apps_module, "current_user", None)
            with pytest.raises(apps_module.QuartAuthUnauthorized) as exc_info:
                await _protected()
            quart_payload, quart_status = await apps_module.unauthorized_quart_auth(exc_info.value)
            assert quart_status == apps_module.RetCode.UNAUTHORIZED
            assert quart_payload["code"] == apps_module.RetCode.UNAUTHORIZED
            werk_payload, werk_status = await apps_module.unauthorized_werkzeug(WerkzeugUnauthorized("Unauthorized 401"))
            assert werk_status == apps_module.RetCode.UNAUTHORIZED
            assert werk_payload["code"] == apps_module.RetCode.UNAUTHORIZED

    _run(_case())