mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-05-03 08:47:48 +08:00
tests: improve RAGFlow coverage based on Codecov report (#13200)
### What problem does this PR solve? Codecov’s coverage report shows that several RAGFlow code paths are currently untested or under-tested. This makes it easier for regressions to slip in during refactors and feature work. This PR adds targeted automated tests to cover the files and branches highlighted by Codecov, improving confidence in core behavior while keeping runtime functionality unchanged. ### Type of change - [x] Other (please describe): Test coverage improvement (adds/extends unit and integration tests to address Codecov-reported gaps)
This commit is contained in:
@ -292,7 +292,7 @@ def knowledge_graph(auth, dataset_id, params=None, *, headers=HEADERS):
|
||||
|
||||
|
||||
def delete_knowledge_graph(auth, dataset_id, payload=None, *, headers=HEADERS, data=None):
    """Delete a dataset's knowledge graph through the web API.

    Issues a single DELETE to the current ``.../{dataset_id}/knowledge_graph``
    endpoint. The defect fixed here: the block issued a second, redundant
    DELETE to the legacy ``delete_knowledge_graph`` path first, whose response
    was discarded — pure overhead against a removed route.

    Args:
        auth: requests auth object for the web session.
        dataset_id: id of the dataset whose graph is removed.
        payload: optional JSON request body.
        headers: HTTP headers (defaults to module-level HEADERS).
        data: optional raw/form body, passed through to requests.

    Returns:
        The parsed JSON response body.
    """
    res = requests.delete(
        url=f"{HOST_ADDRESS}{KB_APP_URL}/{dataset_id}/knowledge_graph",
        headers=headers,
        auth=auth,
        json=payload,
        data=data,
    )
    return res.json()
|
||||
|
||||
|
||||
@ -434,6 +434,11 @@ def update_chunk(auth, payload=None, *, headers=HEADERS):
|
||||
return res.json()
|
||||
|
||||
|
||||
def switch_chunks(auth, payload=None, *, headers=HEADERS):
    """POST to the chunk ``switch`` endpoint and return the decoded JSON body."""
    endpoint = f"{HOST_ADDRESS}{CHUNK_API_URL}/switch"
    response = requests.post(url=endpoint, headers=headers, auth=auth, json=payload)
    return response.json()
|
||||
|
||||
|
||||
def delete_chunks(auth, payload=None, *, headers=HEADERS):
    """POST to the chunk ``rm`` endpoint and return the decoded JSON body."""
    endpoint = f"{HOST_ADDRESS}{CHUNK_API_URL}/rm"
    response = requests.post(url=endpoint, headers=headers, auth=auth, json=payload)
    return response.json()
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
247
test/testcases/test_web_api/test_api_app/test_api_tokens_unit.py
Normal file
247
test/testcases/test_web_api/test_api_app/test_api_tokens_unit.py
Normal file
@ -0,0 +1,247 @@
|
||||
#
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import importlib.util
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from types import ModuleType, SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class _DummyManager:
|
||||
def route(self, *_args, **_kwargs):
|
||||
def decorator(func):
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class _ExprField:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
def __eq__(self, other):
|
||||
return (self.name, other)
|
||||
|
||||
|
||||
class _DummyAPITokenModel:
    # Minimal stand-in for the APIToken DB model: only the two columns the
    # code under test builds filter expressions from.
    tenant_id = _ExprField("tenant_id")
    token = _ExprField("token")
|
||||
|
||||
|
||||
def _run(coro):
|
||||
return asyncio.run(coro)
|
||||
|
||||
|
||||
def _load_api_app(monkeypatch):
    """Load ``api/apps/api_app.py`` in isolation, with every project dependency
    replaced by in-memory stubs registered in ``sys.modules``.

    Returns the freshly executed module object so tests can monkeypatch its
    globals (services, request, helpers) directly.
    """
    repo_root = Path(__file__).resolve().parents[4]

    # Stub quart: only ``request`` is touched by the handlers under test.
    quart_mod = ModuleType("quart")
    quart_mod.request = SimpleNamespace(args={})
    monkeypatch.setitem(sys.modules, "quart", quart_mod)

    # Stub api.apps: no-op login_required and a fixed current_user.
    apps_mod = ModuleType("api.apps")
    apps_mod.__path__ = [str(repo_root / "api" / "apps")]
    apps_mod.login_required = lambda fn: fn
    apps_mod.current_user = SimpleNamespace(id="user-1")
    monkeypatch.setitem(sys.modules, "api.apps", apps_mod)

    # Stub api.utils.api_utils: JSON helpers that echo code/message/data,
    # a fixed confirmation token, and pass-through validation decorator.
    api_utils_mod = ModuleType("api.utils.api_utils")

    async def _get_request_json():
        return {}

    api_utils_mod.generate_confirmation_token = lambda: "token-123"
    api_utils_mod.get_request_json = _get_request_json
    api_utils_mod.get_json_result = lambda data=None, message="", code=0: {
        "code": code,
        "message": message,
        "data": data,
    }
    api_utils_mod.get_data_error_result = lambda message="", code=400, data=None: {
        "code": code,
        "message": message,
        "data": data,
    }
    api_utils_mod.server_error_response = lambda exc: {
        "code": 500,
        "message": str(exc),
        "data": None,
    }
    api_utils_mod.validate_request = lambda *_args, **_kwargs: (lambda fn: fn)
    monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod)

    # Stub the API token / conversation services with inert defaults.
    api_service_mod = ModuleType("api.db.services.api_service")

    class _StubAPITokenService:
        @staticmethod
        def save(**_kwargs):
            return True

        @staticmethod
        def query(**_kwargs):
            return []

        @staticmethod
        def filter_delete(_conds):
            return True

    class _StubAPI4ConversationService:
        @staticmethod
        def stats(*_args, **_kwargs):
            return []

    api_service_mod.APITokenService = _StubAPITokenService
    api_service_mod.API4ConversationService = _StubAPI4ConversationService
    monkeypatch.setitem(sys.modules, "api.db.services.api_service", api_service_mod)

    # Stub the tenant lookup to return a single tenant by default.
    user_service_mod = ModuleType("api.db.services.user_service")

    class _StubUserTenantService:
        @staticmethod
        def query(**_kwargs):
            return [SimpleNamespace(tenant_id="tenant-1")]

    user_service_mod.UserTenantService = _StubUserTenantService
    monkeypatch.setitem(sys.modules, "api.db.services.user_service", user_service_mod)

    # Stub the APIToken DB model used to build delete filter expressions.
    db_models_mod = ModuleType("api.db.db_models")
    db_models_mod.APIToken = _DummyAPITokenModel
    monkeypatch.setitem(sys.modules, "api.db.db_models", db_models_mod)

    # Freeze time helpers so generated records are deterministic.
    time_utils_mod = ModuleType("common.time_utils")
    time_utils_mod.current_timestamp = lambda: 123
    time_utils_mod.datetime_format = lambda _dt: "2026-01-01 00:00:00"
    monkeypatch.setitem(sys.modules, "common.time_utils", time_utils_mod)

    # Execute api_app.py under a private module name, injecting the no-op
    # route manager BEFORE exec so the @manager.route decorators resolve.
    module_path = repo_root / "api" / "apps" / "api_app.py"
    spec = importlib.util.spec_from_file_location("test_api_tokens_unit_module", module_path)
    module = importlib.util.module_from_spec(spec)
    module.manager = _DummyManager()
    spec.loader.exec_module(module)
    return module
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_new_token_branches_and_error_paths(monkeypatch):
    """Cover new_token(): tenant-missing guard, canvas/agent success path,
    save failure, and the 500 path when the tenant query raises."""
    module = _load_api_app(monkeypatch)

    async def req_canvas():
        return {"canvas_id": "canvas-1"}

    monkeypatch.setattr(module, "get_request_json", req_canvas)
    # No tenants -> early data-error response.
    monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [])
    res = _run(module.new_token())
    assert res["message"] == "Tenant not found!"

    # Happy path: canvas_id request creates an agent-sourced token.
    monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [SimpleNamespace(tenant_id="tenant-1")])
    monkeypatch.setattr(module.APITokenService, "save", lambda **_kwargs: True)
    res = _run(module.new_token())
    assert res["code"] == 0
    assert res["data"]["tenant_id"] == "tenant-1"
    assert res["data"]["dialog_id"] == "canvas-1"
    assert res["data"]["source"] == "agent"

    # save() returning False surfaces as a data error.
    monkeypatch.setattr(module.APITokenService, "save", lambda **_kwargs: False)
    res = _run(module.new_token())
    assert res["message"] == "Fail to new a dialog!"

    # Throwaway generator that raises on creation: query() raising drives
    # the handler into server_error_response (code 500).
    monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: (_ for _ in ()).throw(RuntimeError("query failed")))
    res = _run(module.new_token())
    assert res["code"] == 500
    assert "query failed" in res["message"]
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_token_list_tenant_guard_and_exception(monkeypatch):
    """token_list(): tenant-missing guard, then the 500 path when the request
    carries neither dialog_id nor canvas_id (missing-key error mentions canvas_id)."""
    module = _load_api_app(monkeypatch)

    monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [])
    monkeypatch.setattr(module, "request", SimpleNamespace(args={"dialog_id": "d1"}))
    res = module.token_list()
    assert res["message"] == "Tenant not found!"

    # Tenant exists but the request has no id args at all.
    monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [SimpleNamespace(tenant_id="tenant-1")])
    monkeypatch.setattr(module, "request", SimpleNamespace(args={}))
    res = module.token_list()
    assert res["code"] == 500
    assert "canvas_id" in res["message"]
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_rm_exception_path(monkeypatch):
    """rm(): a raising filter_delete is converted into a 500 response."""
    module = _load_api_app(monkeypatch)

    async def req_rm():
        return {"tokens": ["tok-1"], "tenant_id": "tenant-1"}

    monkeypatch.setattr(module, "get_request_json", req_rm)
    # Throwaway generator that raises on creation: filter_delete blows up.
    monkeypatch.setattr(
        module.APITokenService,
        "filter_delete",
        lambda *_args, **_kwargs: (_ for _ in ()).throw(RuntimeError("delete failed")),
    )

    res = _run(module.rm())
    assert res["code"] == 500
    assert "delete failed" in res["message"]
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_stats_aggregation_and_error_paths(monkeypatch):
    """stats(): tenant guard, per-day series aggregation, and the 500 path
    when the underlying stats service raises."""
    module = _load_api_app(monkeypatch)

    monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [])
    monkeypatch.setattr(module, "request", SimpleNamespace(args={}))
    res = module.stats()
    assert res["message"] == "Tenant not found!"

    monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [SimpleNamespace(tenant_id="tenant-1")])
    monkeypatch.setattr(module, "request", SimpleNamespace(args={"canvas_id": "canvas-1"}))
    monkeypatch.setattr(
        module.API4ConversationService,
        "stats",
        lambda *_args, **_kwargs: [
            {
                "dt": "2026-01-01",
                "pv": 3,
                "uv": 2,
                "tokens": 100,
                "duration": 9.9,
                "round": 1,
                "thumb_up": 0,
            }
        ],
    )
    res = module.stats()
    assert res["code"] == 0
    assert res["data"]["pv"] == [("2026-01-01", 3)]
    assert res["data"]["uv"] == [("2026-01-01", 2)]
    assert res["data"]["round"] == [("2026-01-01", 1)]
    assert res["data"]["thumb_up"] == [("2026-01-01", 0)]
    # NOTE(review): expected values suggest tokens are reported scaled down
    # (100 -> 0.1) and speed is derived/rounded from tokens and duration
    # (100, 9.9 -> 10.0) — confirm exact formulas against api_app.stats.
    assert res["data"]["tokens"] == [("2026-01-01", 0.1)]
    assert res["data"]["speed"] == [("2026-01-01", 10.0)]

    monkeypatch.setattr(
        module.API4ConversationService,
        "stats",
        lambda *_args, **_kwargs: (_ for _ in ()).throw(RuntimeError("stats failed")),
    )
    res = module.stats()
    assert res["code"] == 500
    assert "stats failed" in res["message"]
|
||||
@ -0,0 +1,484 @@
|
||||
#
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import importlib.util
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from types import ModuleType, SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class _FakeResponse:
|
||||
def __init__(self, payload=None, err=None):
|
||||
self._payload = payload or {}
|
||||
self._err = err
|
||||
|
||||
def raise_for_status(self):
|
||||
if self._err:
|
||||
raise self._err
|
||||
|
||||
def json(self):
|
||||
return self._payload
|
||||
|
||||
|
||||
class _DummyJwkClient:
|
||||
def __init__(self, _jwks_uri):
|
||||
self._key = "dummy-signing-key"
|
||||
|
||||
def get_signing_key_from_jwt(self, _id_token):
|
||||
return SimpleNamespace(key=self._key)
|
||||
|
||||
|
||||
def _load_auth_modules(monkeypatch):
    """Import api/apps/auth/{oauth,oidc}.py from source with stub parent packages.

    Registers throwaway ``common``/``api``/``api.apps``/``api.apps.auth``
    package modules (real ``__path__`` values, no import side effects) so the
    relative imports inside the auth modules resolve, then execs oauth.py and
    oidc.py in that order.

    Returns:
        (oauth_module, oidc_module) — the freshly loaded module objects.
    """
    repo_root = Path(__file__).resolve().parents[4]

    common_pkg = ModuleType("common")
    common_pkg.__path__ = [str(repo_root / "common")]
    monkeypatch.setitem(sys.modules, "common", common_pkg)

    api_pkg = ModuleType("api")
    api_pkg.__path__ = [str(repo_root / "api")]
    apps_pkg = ModuleType("api.apps")
    apps_pkg.__path__ = [str(repo_root / "api" / "apps")]
    auth_pkg = ModuleType("api.apps.auth")
    auth_pkg.__path__ = [str(repo_root / "api" / "apps" / "auth")]

    monkeypatch.setitem(sys.modules, "api", api_pkg)
    monkeypatch.setitem(sys.modules, "api.apps", apps_pkg)
    monkeypatch.setitem(sys.modules, "api.apps.auth", auth_pkg)

    # Drop any previously cached copies so each test gets a clean import.
    for mod_name in ["api.apps.auth.oauth", "api.apps.auth.oidc"]:
        sys.modules.pop(mod_name, None)

    oauth_path = repo_root / "api" / "apps" / "auth" / "oauth.py"
    oauth_spec = importlib.util.spec_from_file_location("api.apps.auth.oauth", oauth_path)
    oauth_module = importlib.util.module_from_spec(oauth_spec)
    # Register before exec so oidc.py's relative import of oauth resolves.
    monkeypatch.setitem(sys.modules, "api.apps.auth.oauth", oauth_module)
    oauth_spec.loader.exec_module(oauth_module)

    oidc_path = repo_root / "api" / "apps" / "auth" / "oidc.py"
    oidc_spec = importlib.util.spec_from_file_location("api.apps.auth.oidc", oidc_path)
    oidc_module = importlib.util.module_from_spec(oidc_spec)
    monkeypatch.setitem(sys.modules, "api.apps.auth.oidc", oidc_module)
    oidc_spec.loader.exec_module(oidc_module)

    return oauth_module, oidc_module
|
||||
|
||||
|
||||
def _load_github_module(monkeypatch):
    """Load api/apps/auth/github.py from source on top of the stubbed auth packages."""
    _load_auth_modules(monkeypatch)
    repo_root = Path(__file__).resolve().parents[4]

    # Evict any cached copy so the module is executed fresh for this test.
    sys.modules.pop("api.apps.auth.github", None)
    github_path = repo_root / "api" / "apps" / "auth" / "github.py"
    github_spec = importlib.util.spec_from_file_location("api.apps.auth.github", github_path)
    github_module = importlib.util.module_from_spec(github_spec)
    monkeypatch.setitem(sys.modules, "api.apps.auth.github", github_module)
    github_spec.loader.exec_module(github_module)
    return github_module
|
||||
|
||||
|
||||
def _load_auth_init_module(monkeypatch):
    """Load api/apps/auth/__init__.py with GithubOAuthClient stubbed out.

    The stub keeps the real github client (and its HTTP helpers) from being
    imported while still letting get_auth_client() dispatch on client type.
    Returns the executed package __init__ module.
    """
    _load_auth_modules(monkeypatch)
    repo_root = Path(__file__).resolve().parents[4]

    github_mod = ModuleType("api.apps.auth.github")

    class _StubGithubOAuthClient:
        def __init__(self, config):
            self.config = config

    github_mod.GithubOAuthClient = _StubGithubOAuthClient
    monkeypatch.setitem(sys.modules, "api.apps.auth.github", github_mod)

    init_path = repo_root / "api" / "apps" / "auth" / "__init__.py"
    # submodule_search_locations makes the loaded module a proper package so
    # its own submodule imports keep working.
    init_spec = importlib.util.spec_from_file_location(
        "api.apps.auth",
        init_path,
        submodule_search_locations=[str(repo_root / "api" / "apps" / "auth")],
    )
    init_module = importlib.util.module_from_spec(init_spec)
    monkeypatch.setitem(sys.modules, "api.apps.auth", init_module)
    init_spec.loader.exec_module(init_module)
    return init_module
|
||||
|
||||
|
||||
def _base_config():
|
||||
return {
|
||||
"issuer": "https://issuer.example",
|
||||
"client_id": "client-1",
|
||||
"client_secret": "secret-1",
|
||||
"redirect_uri": "https://app.example/callback",
|
||||
}
|
||||
|
||||
|
||||
def _metadata(issuer):
|
||||
return {
|
||||
"issuer": issuer,
|
||||
"jwks_uri": f"{issuer}/jwks",
|
||||
"authorization_endpoint": f"{issuer}/authorize",
|
||||
"token_endpoint": f"{issuer}/token",
|
||||
"userinfo_endpoint": f"{issuer}/userinfo",
|
||||
}
|
||||
|
||||
|
||||
def _make_client(monkeypatch, oidc_module):
    # Build an OIDCClient with no network access: metadata discovery is
    # replaced by the static _metadata() document for the configured issuer.
    monkeypatch.setattr(oidc_module.OIDCClient, "_load_oidc_metadata", staticmethod(lambda issuer: _metadata(issuer)))
    return oidc_module.OIDCClient(_base_config())
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_oidc_init_requires_issuer(monkeypatch):
    """OIDCClient must reject a configuration without an issuer."""
    _, oidc_module = _load_auth_modules(monkeypatch)

    with pytest.raises(ValueError) as exc_info:
        oidc_module.OIDCClient({"client_id": "cid"})

    assert str(exc_info.value) == "Missing issuer in configuration."
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_oidc_init_loads_metadata_and_sets_endpoints(monkeypatch):
    """A valid config populates issuer + all endpoint URLs from discovery metadata."""
    _, oidc_module = _load_auth_modules(monkeypatch)
    monkeypatch.setattr(oidc_module.OIDCClient, "_load_oidc_metadata", staticmethod(lambda issuer: _metadata(issuer)))

    client = oidc_module.OIDCClient(_base_config())

    assert client.issuer == "https://issuer.example"
    assert client.jwks_uri == "https://issuer.example/jwks"
    assert client.authorization_url == "https://issuer.example/authorize"
    assert client.token_url == "https://issuer.example/token"
    assert client.userinfo_url == "https://issuer.example/userinfo"
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_load_oidc_metadata_success_and_wraps_failure(monkeypatch):
    """_load_oidc_metadata(): GETs the well-known discovery URL with a 7s
    timeout, and wraps any transport failure in a ValueError."""
    _, oidc_module = _load_auth_modules(monkeypatch)

    calls = {}

    def _ok_sync_request(method, url, timeout):
        calls.update({"method": method, "url": url, "timeout": timeout})
        return _FakeResponse(_metadata("https://issuer.example"))

    monkeypatch.setattr(oidc_module, "sync_request", _ok_sync_request)
    metadata = oidc_module.OIDCClient._load_oidc_metadata("https://issuer.example")
    assert metadata["jwks_uri"] == "https://issuer.example/jwks"
    assert calls == {
        "method": "GET",
        "url": "https://issuer.example/.well-known/openid-configuration",
        "timeout": 7,
    }

    def _boom_sync_request(*_args, **_kwargs):
        raise RuntimeError("metadata boom")

    monkeypatch.setattr(oidc_module, "sync_request", _boom_sync_request)
    with pytest.raises(ValueError) as exc_info:
        oidc_module.OIDCClient._load_oidc_metadata("https://issuer.example")
    assert str(exc_info.value) == "Failed to fetch OIDC metadata: metadata boom"
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_parse_id_token_success_and_error(monkeypatch):
    """parse_id_token(): decodes via PyJWKClient + jwt.decode using the
    client's audience/issuer; decode failures are wrapped in ValueError."""
    _, oidc_module = _load_auth_modules(monkeypatch)
    client = _make_client(monkeypatch, oidc_module)

    monkeypatch.setattr(oidc_module.jwt, "get_unverified_header", lambda _token: {})

    # Records every value the code under test passes through the JWT stack.
    seen = {}

    class _JwkClient(_DummyJwkClient):
        def __init__(self, jwks_uri):
            super().__init__(jwks_uri)
            seen["jwks_uri"] = jwks_uri

        def get_signing_key_from_jwt(self, id_token):
            seen["id_token"] = id_token
            return super().get_signing_key_from_jwt(id_token)

    monkeypatch.setattr(oidc_module.jwt, "PyJWKClient", _JwkClient)

    def _decode(id_token, key, algorithms, audience, issuer):
        seen.update(
            {
                "decode_id_token": id_token,
                "decode_key": key,
                "algorithms": algorithms,
                "audience": audience,
                "issuer": issuer,
            }
        )
        return {"sub": "user-1", "email": "id@example.com"}

    monkeypatch.setattr(oidc_module.jwt, "decode", _decode)
    parsed = client.parse_id_token("id-token-1")

    assert parsed["sub"] == "user-1"
    assert seen["jwks_uri"] == "https://issuer.example/jwks"
    assert seen["decode_key"] == "dummy-signing-key"
    assert seen["algorithms"] == ["RS256"]
    assert seen["audience"] == "client-1"
    assert seen["issuer"] == "https://issuer.example"

    def _raise_decode(*_args, **_kwargs):
        raise RuntimeError("decode boom")

    monkeypatch.setattr(oidc_module.jwt, "decode", _raise_decode)
    with pytest.raises(ValueError) as exc_info:
        client.parse_id_token("id-token-2")
    assert str(exc_info.value) == "Error parsing ID Token: decode boom"
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_fetch_user_info_merges_id_token_and_oauth_userinfo(monkeypatch):
    """fetch_user_info(): OAuth userinfo fields win; the id_token's picture
    fills in the missing avatar_url."""
    oauth_module, oidc_module = _load_auth_modules(monkeypatch)
    client = _make_client(monkeypatch, oidc_module)

    monkeypatch.setattr(
        oidc_module.OIDCClient,
        "parse_id_token",
        lambda self, _id_token: {"picture": "id-picture", "email": "id@example.com"},
    )

    def _fake_parent_fetch(self, access_token, **_kwargs):
        assert access_token == "access-1"
        return oauth_module.UserInfo(
            email="oauth@example.com",
            username="oauth-user",
            nickname="oauth-nick",
            avatar_url=None,
        )

    monkeypatch.setattr(oauth_module.OAuthClient, "fetch_user_info", _fake_parent_fetch)

    info = client.fetch_user_info("access-1", id_token="id-token")

    assert info.email == "oauth@example.com"
    assert info.username == "oauth-user"
    assert info.nickname == "oauth-nick"
    assert info.avatar_url == "id-picture"
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_async_fetch_user_info_merges_id_token_and_oauth_userinfo(monkeypatch):
    """async_fetch_user_info(): same merge semantics as the sync variant."""
    oauth_module, oidc_module = _load_auth_modules(monkeypatch)
    client = _make_client(monkeypatch, oidc_module)

    monkeypatch.setattr(
        oidc_module.OIDCClient,
        "parse_id_token",
        lambda self, _id_token: {"picture": "id-picture-async", "email": "id-async@example.com"},
    )

    async def _fake_parent_async_fetch(self, access_token, **_kwargs):
        assert access_token == "access-2"
        return oauth_module.UserInfo(
            email="oauth-async@example.com",
            username="oauth-async-user",
            nickname="oauth-async-nick",
            avatar_url=None,
        )

    monkeypatch.setattr(oauth_module.OAuthClient, "async_fetch_user_info", _fake_parent_async_fetch)

    info = asyncio.run(client.async_fetch_user_info("access-2", id_token="id-token"))

    assert info.email == "oauth-async@example.com"
    assert info.username == "oauth-async-user"
    assert info.nickname == "oauth-async-nick"
    assert info.avatar_url == "id-picture-async"
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_normalize_user_info_passthrough(monkeypatch):
    """normalize_user_info() maps OIDC claims onto UserInfo (picture -> avatar_url)."""
    oauth_module, oidc_module = _load_auth_modules(monkeypatch)
    client = _make_client(monkeypatch, oidc_module)

    result = client.normalize_user_info(
        {
            "email": "user@example.com",
            "username": "user",
            "nickname": "User",
            "picture": "picture-url",
        }
    )

    assert isinstance(result, oauth_module.UserInfo)
    assert result.to_dict() == {
        "email": "user@example.com",
        "username": "user",
        "nickname": "User",
        "avatar_url": "picture-url",
    }
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_get_auth_client_type_inference_and_unsupported(monkeypatch):
    """get_auth_client(): infers oidc when an issuer is present, defaults to
    oauth2 otherwise, and raises ValueError for an unknown explicit type."""
    auth_module = _load_auth_init_module(monkeypatch)

    class _FakeOAuth2Client:
        def __init__(self, config):
            self.config = config

    class _FakeOidcClient:
        def __init__(self, config):
            self.config = config

    class _FakeGithubClient:
        def __init__(self, config):
            self.config = config

    # Swap the real client registry for recorder fakes.
    monkeypatch.setattr(
        auth_module,
        "CLIENT_TYPES",
        {
            "oauth2": _FakeOAuth2Client,
            "oidc": _FakeOidcClient,
            "github": _FakeGithubClient,
        },
    )

    oidc_client = auth_module.get_auth_client({"issuer": "https://issuer.example"})
    assert isinstance(oidc_client, _FakeOidcClient)

    oauth_client = auth_module.get_auth_client({})
    assert isinstance(oauth_client, _FakeOAuth2Client)

    with pytest.raises(ValueError, match="Unsupported type: invalid"):
        auth_module.get_auth_client({"type": "invalid"})
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_github_oauth_client_init_and_normalize_unit(monkeypatch):
    """GithubOAuthClient: fixed GitHub endpoints/scope, and normalize_user_info
    maps login/name/avatar_url with email-prefix fallbacks."""
    github_module = _load_github_module(monkeypatch)

    client = github_module.GithubOAuthClient(_base_config())
    assert client.authorization_url == "https://github.com/login/oauth/authorize"
    assert client.token_url == "https://github.com/login/oauth/access_token"
    assert client.userinfo_url == "https://api.github.com/user"
    assert client.scope == "user:email"

    normalized = client.normalize_user_info(
        {
            "email": "octo@example.com",
            "login": "octocat",
            "name": "Octo Cat",
            "avatar_url": "https://avatar.example/octocat.png",
        }
    )
    assert normalized.to_dict() == {
        "email": "octo@example.com",
        "username": "octocat",
        "nickname": "Octo Cat",
        "avatar_url": "https://avatar.example/octocat.png",
    }

    # Missing login/name fall back to the email local part; missing avatar -> "".
    normalized_fallback = client.normalize_user_info({"email": "fallback@example.com"})
    assert normalized_fallback.to_dict() == {
        "email": "fallback@example.com",
        "username": "fallback",
        "nickname": "fallback",
        "avatar_url": "",
    }
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_github_fetch_user_info_sync_success_and_error_unit(monkeypatch):
    """fetch_user_info(): GETs /user then /user/emails with a Bearer header and
    7s timeout, picks the primary email, and wraps HTTP errors in ValueError."""
    github_module = _load_github_module(monkeypatch)
    client = github_module.GithubOAuthClient(_base_config())

    calls = []

    def _fake_sync_request(method, url, headers=None, timeout=None):
        calls.append((method, url, headers, timeout))
        if url.endswith("/emails"):
            return _FakeResponse(
                [
                    {"email": "other@example.com", "primary": False},
                    {"email": "octo@example.com", "primary": True},
                ]
            )
        return _FakeResponse({"login": "octocat", "name": "Octo Cat", "avatar_url": "https://avatar.example/octocat.png"})

    monkeypatch.setattr(github_module, "sync_request", _fake_sync_request)
    info = client.fetch_user_info("sync-token")

    assert info.to_dict() == {
        "email": "octo@example.com",
        "username": "octocat",
        "nickname": "Octo Cat",
        "avatar_url": "https://avatar.example/octocat.png",
    }
    assert [call[1] for call in calls] == [
        "https://api.github.com/user",
        "https://api.github.com/user/emails",
    ]
    assert all(call[2]["Authorization"] == "Bearer sync-token" for call in calls)
    assert all(call[3] == 7 for call in calls)

    # raise_for_status() failing must surface as a wrapped ValueError.
    def _sync_request_raises(*_args, **_kwargs):
        return _FakeResponse(err=RuntimeError("status boom"))

    monkeypatch.setattr(github_module, "sync_request", _sync_request_raises)
    with pytest.raises(ValueError, match="Failed to fetch github user info: status boom"):
        client.fetch_user_info("sync-token")
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_github_fetch_user_info_async_success_and_error_unit(monkeypatch):
    """async_fetch_user_info(): same request sequence, headers, primary-email
    selection, and error wrapping as the sync variant."""
    github_module = _load_github_module(monkeypatch)
    client = github_module.GithubOAuthClient(_base_config())

    calls = []

    async def _fake_async_request(method, url, headers=None, **kwargs):
        calls.append((method, url, headers, kwargs.get("timeout")))
        if url.endswith("/emails"):
            return _FakeResponse(
                [
                    {"email": "other@example.com", "primary": False},
                    {"email": "octo-async@example.com", "primary": True},
                ]
            )
        return _FakeResponse(
            {"login": "octocat-async", "name": "Octo Async", "avatar_url": "https://avatar.example/octo-async.png"}
        )

    monkeypatch.setattr(github_module, "async_request", _fake_async_request)
    info = asyncio.run(client.async_fetch_user_info("async-token"))

    assert info.to_dict() == {
        "email": "octo-async@example.com",
        "username": "octocat-async",
        "nickname": "Octo Async",
        "avatar_url": "https://avatar.example/octo-async.png",
    }
    assert [call[1] for call in calls] == [
        "https://api.github.com/user",
        "https://api.github.com/user/emails",
    ]
    assert all(call[2]["Authorization"] == "Bearer async-token" for call in calls)
    assert all(call[3] == 7 for call in calls)

    async def _async_request_raises(*_args, **_kwargs):
        return _FakeResponse(err=RuntimeError("async status boom"))

    monkeypatch.setattr(github_module, "async_request", _async_request_raises)
    with pytest.raises(ValueError, match="Failed to fetch github user info: async status boom"):
        asyncio.run(client.async_fetch_user_info("async-token"))
|
||||
@ -0,0 +1,669 @@
|
||||
#
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import importlib.util
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from types import ModuleType, SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class _DummyManager:
|
||||
def route(self, *_args, **_kwargs):
|
||||
def decorator(func):
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class _AwaitableValue:
|
||||
def __init__(self, value):
|
||||
self._value = value
|
||||
|
||||
def __await__(self):
|
||||
async def _co():
|
||||
return self._value
|
||||
|
||||
return _co().__await__()
|
||||
|
||||
|
||||
class _Vec(list):
|
||||
def __mul__(self, scalar):
|
||||
return _Vec([scalar * x for x in self])
|
||||
|
||||
__rmul__ = __mul__
|
||||
|
||||
def __add__(self, other):
|
||||
return _Vec([a + b for a, b in zip(self, other)])
|
||||
|
||||
def tolist(self):
|
||||
return list(self)
|
||||
|
||||
|
||||
class _DummyDoc:
|
||||
def __init__(self, *, doc_id="doc-1", kb_id="kb-1", name="Doc", parser_id="naive"):
|
||||
self.id = doc_id
|
||||
self.kb_id = kb_id
|
||||
self.name = name
|
||||
self.parser_id = parser_id
|
||||
|
||||
def to_dict(self):
|
||||
return {"id": self.id, "kb_id": self.kb_id, "name": self.name}
|
||||
|
||||
|
||||
class _DummyRetCode:
    # Mirrors the RetCode constants referenced by the chunk endpoints.
    SUCCESS = 0
    DATA_ERROR = 102
    EXCEPTION_ERROR = 100
|
||||
|
||||
|
||||
class _DummyParserType:
    # Mirrors the ParserType enum values the chunk code branches on.
    QA = "qa"
    NAIVE = "naive"
|
||||
|
||||
|
||||
class _DummyRetriever:
    """Retriever double whose search() always yields one fixed chunk hit."""

    async def search(self, query, _index_name, _kb_ids, highlight=None):
        # Result object shaped like the real search result; class attributes
        # suffice because only one canned hit is ever produced.
        class _SRes:
            total = 1
            ids = ["chunk-1"]
            field = {
                "chunk-1": {
                    "content_with_weight": "chunk content",
                    "doc_id": "doc-1",
                    "docnm_kwd": "Doc",
                    "important_kwd": ["k1"],
                    "question_kwd": ["q1"],
                    "img_id": "img-1",
                    "available_int": 1,
                    "position_int": [],
                    "doc_type_kwd": "text",
                }
            }
            highlight = {"chunk-1": " highlighted content "}

        # Parameters are accepted for signature compatibility but unused.
        _ = (query, highlight)
        return _SRes()
|
||||
|
||||
|
||||
class _DummyDocStore:
    """In-memory doc-store double that records every mutation it receives.

    ``to_delete`` is a script of return values for successive ``delete``
    calls; once exhausted, ``delete`` reports 0 (nothing removed).
    """

    def __init__(self):
        self.updated = []
        self.inserted = []
        self.deleted_inputs = []
        self.to_delete = [1]
        # Canned chunk returned (as a copy) by get(); set to None to simulate a miss.
        self.chunk = {
            "id": "chunk-1",
            "doc_id": "doc-1",
            "kb_id": "kb-1",
            "content_with_weight": "chunk content",
            "docnm_kwd": "Doc",
            "q_2_vec": [0.1, 0.2],
            "content_tks": ["a"],
            "content_ltks": ["b"],
            "content_sm_ltks": ["c"],
        }

    def get(self, *_args, **_kwargs):
        """Return a fresh copy of the canned chunk, or None when unset."""
        if self.chunk is None:
            return None
        return dict(self.chunk)

    def update(self, condition, payload, *_args, **_kwargs):
        """Record the update and always report success."""
        self.updated.append((condition, payload))
        return True

    def delete(self, condition, *_args, **_kwargs):
        """Record the condition; return the next scripted count (0 when exhausted)."""
        self.deleted_inputs.append(condition)
        return self.to_delete.pop(0) if self.to_delete else 0

    def insert(self, docs, *_args, **_kwargs):
        """Accumulate inserted docs for later assertions."""
        self.inserted.extend(docs)
|
||||
|
||||
|
||||
class _DummyStorage:
    """Blob-storage stub: remembers put()/rm() calls and pretends every object exists."""

    def __init__(self):
        self.put_calls, self.rm_calls = [], []

    def put(self, bucket, name, binary):
        """Record an upload without storing anything."""
        self.put_calls.append((bucket, name, binary))

    def obj_exist(self, _bucket, _name):
        """Every object is reported present."""
        return True

    def rm(self, bucket, name):
        """Record a removal."""
        self.rm_calls.append((bucket, name))
|
||||
|
||||
|
||||
class _DummyTenant:
    """Tenant membership record stub: routes only read ``tenant_id``."""

    def __init__(self, tenant_id="tenant-1"):
        self.tenant_id = tenant_id
|
||||
|
||||
|
||||
class _DummyLLMBundle:
    """Embedding-model double: fixed 2-d vectors plus a token count of 9."""

    def __init__(self, *_args, **_kwargs):
        # Swallow whatever the app passes (tenant id, LLMType, model name, ...).
        pass

    def encode(self, _inputs):
        """Return two canned embedding vectors and a fake token usage count."""
        vectors = [_Vec([1.0, 2.0]), _Vec([3.0, 4.0])]
        return vectors, 9
|
||||
|
||||
|
||||
class _DummyXXHash:
    """xxhash.xxh64 stand-in: deterministic digest derived from input length."""

    def __init__(self, data):
        self._payload = data

    def hexdigest(self):
        # Deterministic and collision-friendly on length — sufficient for id stubs.
        return f"chunk-{len(self._payload)}"
|
||||
|
||||
|
||||
def _run(coro):
    """Drive the async route coroutine to completion on a fresh event loop."""
    outcome = asyncio.run(coro)
    return outcome
|
||||
|
||||
|
||||
def _load_chunk_module(monkeypatch):
    """Import ``api/apps/chunk_app.py`` in isolation with all of its imports stubbed.

    Registers fake modules in ``sys.modules`` (via monkeypatch, so they are
    removed after the test) for quart, xxhash, the ``common`` / ``rag`` /
    ``api`` packages and every service the routes touch, then executes
    chunk_app.py from its file path and returns the loaded module.  The
    registration order matters: parent packages must exist before their
    submodules are registered and before chunk_app's imports resolve.
    """
    # parents[4] assumes this test file sits 4 levels below the repo root — TODO confirm.
    repo_root = Path(__file__).resolve().parents[4]

    # --- third-party stand-ins -------------------------------------------------
    quart_mod = ModuleType("quart")
    quart_mod.request = SimpleNamespace(args={}, headers={})
    monkeypatch.setitem(sys.modules, "quart", quart_mod)

    xxhash_mod = ModuleType("xxhash")
    xxhash_mod.xxh64 = lambda data: _DummyXXHash(data)
    monkeypatch.setitem(sys.modules, "xxhash", xxhash_mod)

    # --- common.* package ------------------------------------------------------
    common_pkg = ModuleType("common")
    common_pkg.__path__ = [str(repo_root / "common")]
    monkeypatch.setitem(sys.modules, "common", common_pkg)

    settings_mod = ModuleType("common.settings")
    settings_mod.retriever = _DummyRetriever()
    settings_mod.docStoreConn = _DummyDocStore()
    settings_mod.STORAGE_IMPL = _DummyStorage()
    monkeypatch.setitem(sys.modules, "common.settings", settings_mod)
    common_pkg.settings = settings_mod

    constants_mod = ModuleType("common.constants")

    class _DummyLLMType:
        # Only the .value attribute is ever read by the routes.
        EMBEDDING = SimpleNamespace(value="embedding")
        CHAT = SimpleNamespace(value="chat")

    constants_mod.RetCode = _DummyRetCode
    constants_mod.LLMType = _DummyLLMType
    constants_mod.ParserType = _DummyParserType
    constants_mod.PAGERANK_FLD = "pagerank_flt"
    monkeypatch.setitem(sys.modules, "common.constants", constants_mod)

    string_utils_mod = ModuleType("common.string_utils")
    string_utils_mod.remove_redundant_spaces = lambda text: " ".join(str(text).split())
    monkeypatch.setitem(sys.modules, "common.string_utils", string_utils_mod)

    metadata_utils_mod = ModuleType("common.metadata_utils")
    metadata_utils_mod.apply_meta_data_filter = lambda *_args, **_kwargs: {}
    monkeypatch.setitem(sys.modules, "common.metadata_utils", metadata_utils_mod)

    misc_utils_mod = ModuleType("common.misc_utils")

    # Run the callable inline instead of on a worker thread — keeps tests synchronous.
    async def _thread_pool_exec(func):
        return func()

    misc_utils_mod.thread_pool_exec = _thread_pool_exec
    monkeypatch.setitem(sys.modules, "common.misc_utils", misc_utils_mod)

    # --- rag.* package ---------------------------------------------------------
    rag_pkg = ModuleType("rag")
    rag_pkg.__path__ = []
    monkeypatch.setitem(sys.modules, "rag", rag_pkg)

    rag_app_pkg = ModuleType("rag.app")
    rag_app_pkg.__path__ = []
    monkeypatch.setitem(sys.modules, "rag.app", rag_app_pkg)

    rag_qa_mod = ModuleType("rag.app.qa")
    # NOTE(review): str.strip("Q: ") strips any of the chars {'Q', ':', ' '}, not the
    # literal prefix — a loose but adequate approximation for these stubs.
    rag_qa_mod.rmPrefix = lambda text: str(text).strip("Q: ").strip("A: ")
    rag_qa_mod.beAdoc = lambda d, q, a, _latin: {**d, "question_kwd": [q], "content_with_weight": f"{q}\n{a}"}
    monkeypatch.setitem(sys.modules, "rag.app.qa", rag_qa_mod)

    rag_tag_mod = ModuleType("rag.app.tag")
    rag_tag_mod.label_question = lambda *_args, **_kwargs: []
    monkeypatch.setitem(sys.modules, "rag.app.tag", rag_tag_mod)

    rag_nlp_mod = ModuleType("rag.nlp")
    rag_nlp_mod.rag_tokenizer = SimpleNamespace(
        tokenize=lambda text: [str(text)],
        fine_grained_tokenize=lambda toks: [f"fg:{t}" for t in toks],
        is_chinese=lambda _text: False,
    )
    rag_nlp_mod.search = SimpleNamespace(index_name=lambda tenant_id: f"idx-{tenant_id}")
    monkeypatch.setitem(sys.modules, "rag.nlp", rag_nlp_mod)

    rag_prompts_pkg = ModuleType("rag.prompts")
    rag_prompts_pkg.__path__ = []
    monkeypatch.setitem(sys.modules, "rag.prompts", rag_prompts_pkg)

    rag_generator_mod = ModuleType("rag.prompts.generator")
    rag_generator_mod.cross_languages = lambda *_args, **_kwargs: []
    rag_generator_mod.keyword_extraction = lambda *_args, **_kwargs: []
    monkeypatch.setitem(sys.modules, "rag.prompts.generator", rag_generator_mod)

    # --- api.* package: auth, response helpers, services -----------------------
    apps_mod = ModuleType("api.apps")
    apps_mod.__path__ = [str(repo_root / "api" / "apps")]
    apps_mod.current_user = SimpleNamespace(id="user-1")
    # Disable authentication: the decorator becomes a pass-through.
    apps_mod.login_required = lambda func: func
    monkeypatch.setitem(sys.modules, "api.apps", apps_mod)

    api_utils_mod = ModuleType("api.utils.api_utils")
    api_utils_mod.get_json_result = lambda data=None, message="", code=0: {"code": code, "message": message, "data": data}
    api_utils_mod.get_data_error_result = lambda message="": {"code": _DummyRetCode.DATA_ERROR, "message": message, "data": False}
    api_utils_mod.server_error_response = lambda exc: {"code": _DummyRetCode.EXCEPTION_ERROR, "message": repr(exc), "data": False}
    api_utils_mod.validate_request = lambda *_args, **_kwargs: (lambda fn: fn)
    api_utils_mod.get_request_json = lambda: _AwaitableValue({})
    monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod)

    services_pkg = ModuleType("api.db.services")
    services_pkg.__path__ = []
    monkeypatch.setitem(sys.modules, "api.db.services", services_pkg)

    document_service_mod = ModuleType("api.db.services.document_service")

    class _DocumentService:
        # Class-level call recorders shared by all tests that use this loader.
        decrement_calls = []
        increment_calls = []

        @staticmethod
        def get_tenant_id(_doc_id):
            return "tenant-1"

        @staticmethod
        def get_by_id(doc_id):
            return True, _DummyDoc(doc_id=doc_id, parser_id=_DummyParserType.NAIVE)

        @staticmethod
        def get_embd_id(_doc_id):
            return "embed-1"

        @staticmethod
        def decrement_chunk_num(*args):
            _DocumentService.decrement_calls.append(args)

        @staticmethod
        def increment_chunk_num(*args):
            _DocumentService.increment_calls.append(args)

    document_service_mod.DocumentService = _DocumentService
    monkeypatch.setitem(sys.modules, "api.db.services.document_service", document_service_mod)
    services_pkg.document_service = document_service_mod

    doc_metadata_service_mod = ModuleType("api.db.services.doc_metadata_service")
    doc_metadata_service_mod.DocMetadataService = type("DocMetadataService", (), {})
    monkeypatch.setitem(sys.modules, "api.db.services.doc_metadata_service", doc_metadata_service_mod)
    services_pkg.doc_metadata_service = doc_metadata_service_mod

    kb_service_mod = ModuleType("api.db.services.knowledgebase_service")

    class _KnowledgebaseService:
        @staticmethod
        def get_kb_ids(_tenant_id):
            return ["kb-1"]

        @staticmethod
        def get_by_id(_kb_id):
            return True, SimpleNamespace(pagerank=0.6)

    kb_service_mod.KnowledgebaseService = _KnowledgebaseService
    monkeypatch.setitem(sys.modules, "api.db.services.knowledgebase_service", kb_service_mod)
    services_pkg.knowledgebase_service = kb_service_mod

    llm_service_mod = ModuleType("api.db.services.llm_service")
    llm_service_mod.LLMBundle = _DummyLLMBundle
    monkeypatch.setitem(sys.modules, "api.db.services.llm_service", llm_service_mod)
    services_pkg.llm_service = llm_service_mod

    search_service_mod = ModuleType("api.db.services.search_service")
    search_service_mod.SearchService = type("SearchService", (), {})
    monkeypatch.setitem(sys.modules, "api.db.services.search_service", search_service_mod)
    services_pkg.search_service = search_service_mod

    user_service_mod = ModuleType("api.db.services.user_service")

    class _UserTenantService:
        @staticmethod
        def query(**_kwargs):
            return [_DummyTenant("tenant-1")]

    user_service_mod.UserTenantService = _UserTenantService
    monkeypatch.setitem(sys.modules, "api.db.services.user_service", user_service_mod)
    services_pkg.user_service = user_service_mod

    # --- load the route module itself ------------------------------------------
    module_name = "test_chunk_routes_unit_module"
    module_path = repo_root / "api" / "apps" / "chunk_app.py"
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    module = importlib.util.module_from_spec(spec)
    # Inject the route-registration stub before executing the module body.
    module.manager = _DummyManager()
    monkeypatch.setitem(sys.modules, module_name, module)
    spec.loader.exec_module(module)
    return module
|
||||
|
||||
|
||||
def _set_request_json(monkeypatch, module, payload):
    """Patch the route module so ``get_request_json()`` resolves to *payload* when awaited."""

    def _fake_request_json():
        return _AwaitableValue(payload)

    monkeypatch.setattr(module, "get_request_json", _fake_request_json)
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_list_chunk_exception_branches_unit(monkeypatch):
    """Exercise list_chunk: happy path, missing tenant/document, and both error branches."""
    module = _load_chunk_module(monkeypatch)

    # Happy path: the retriever stub returns one chunk with available_int == 1.
    _set_request_json(monkeypatch, module, {"doc_id": "doc-1", "keywords": "chunk", "available_int": 0})
    res = _run(module.list_chunk())
    assert res["code"] == 0, res
    assert res["data"]["total"] == 1, res
    assert res["data"]["chunks"][0]["available_int"] == 1, res

    # Empty tenant id -> data error.
    monkeypatch.setattr(module.DocumentService, "get_tenant_id", lambda _doc_id: "")
    _set_request_json(monkeypatch, module, {"doc_id": "doc-1"})
    res = _run(module.list_chunk())
    assert res["code"] == module.RetCode.DATA_ERROR, res
    assert res["message"] == "Tenant not found!", res

    # Document lookup failure -> data error.
    monkeypatch.setattr(module.DocumentService, "get_tenant_id", lambda _doc_id: "tenant-1")
    monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (False, None))
    _set_request_json(monkeypatch, module, {"doc_id": "doc-1"})
    res = _run(module.list_chunk())
    assert res["message"] == "Document not found!", res

    # A search exception containing "not_found" is translated into "No chunk found!".
    async def _raise_not_found(*_args, **_kwargs):
        raise Exception("x not_found y")

    monkeypatch.setattr(module.settings.retriever, "search", _raise_not_found)
    monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (True, _DummyDoc()))
    _set_request_json(monkeypatch, module, {"doc_id": "doc-1"})
    res = _run(module.list_chunk())
    assert res["code"] == module.RetCode.DATA_ERROR, res
    assert res["message"] == "No chunk found!", res

    # Any other exception surfaces via server_error_response.
    async def _raise_generic(*_args, **_kwargs):
        raise RuntimeError("boom")

    monkeypatch.setattr(module.settings.retriever, "search", _raise_generic)
    _set_request_json(monkeypatch, module, {"doc_id": "doc-1"})
    res = _run(module.list_chunk())
    assert res["code"] == module.RetCode.EXCEPTION_ERROR, res
    assert "boom" in res["message"], res
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_get_chunk_sanitize_and_exception_matrix_unit(monkeypatch):
    """Exercise get(): field sanitisation, missing tenant, and all three error shapes."""
    module = _load_chunk_module(monkeypatch)
    # get() reads the chunk id from the (stubbed) quart request args.
    module.request = SimpleNamespace(args={"chunk_id": "chunk-1"}, headers={})

    # Happy path: internal vector/token fields are stripped from the response.
    res = module.get()
    assert res["code"] == 0, res
    assert "q_2_vec" not in res["data"], res
    assert "content_tks" not in res["data"], res
    assert "content_ltks" not in res["data"], res
    assert "content_sm_ltks" not in res["data"], res

    # No tenant memberships -> data error.
    monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [])
    res = module.get()
    assert res["message"] == "Tenant not found!", res

    # Store returns None -> generic exception path with a "Chunk not found" message.
    monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [_DummyTenant("tenant-1")])
    module.settings.docStoreConn.chunk = None
    res = module.get()
    assert res["code"] == module.RetCode.EXCEPTION_ERROR, res
    assert "Chunk not found" in res["message"], res

    # An exception mentioning NotFoundError is converted into a DATA_ERROR.
    def _raise_not_found(*_args, **_kwargs):
        raise Exception("NotFoundError: chunk-1")

    monkeypatch.setattr(module.settings.docStoreConn, "get", _raise_not_found)
    res = module.get()
    assert res["code"] == module.RetCode.DATA_ERROR, res
    assert res["message"] == "Chunk not found!", res

    # Any other exception surfaces via server_error_response.
    def _raise_generic(*_args, **_kwargs):
        raise RuntimeError("get boom")

    monkeypatch.setattr(module.settings.docStoreConn, "get", _raise_generic)
    res = module.get()
    assert res["code"] == module.RetCode.EXCEPTION_ERROR, res
    assert "get boom" in res["message"], res
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_set_chunk_bytes_qa_image_and_guard_matrix_unit(monkeypatch):
    """Exercise set(): input guards, bytes content, the QA branch, image upload, and errors."""
    module = _load_chunk_module(monkeypatch)

    # Non-string content reaches the tokenizer/regex layer and raises TypeError.
    _set_request_json(monkeypatch, module, {"doc_id": "doc-1", "chunk_id": "chunk-1", "content_with_weight": 1})
    with pytest.raises(TypeError, match="expected string or bytes-like object"):
        _run(module.set())

    # Keyword fields must be lists.
    _set_request_json(
        monkeypatch,
        module,
        {"doc_id": "doc-1", "chunk_id": "chunk-1", "content_with_weight": "abc", "important_kwd": "bad"},
    )
    res = _run(module.set())
    assert res["message"] == "`important_kwd` should be a list", res

    _set_request_json(
        monkeypatch,
        module,
        {"doc_id": "doc-1", "chunk_id": "chunk-1", "content_with_weight": "abc", "question_kwd": "bad"},
    )
    res = _run(module.set())
    assert res["message"] == "`question_kwd` should be a list", res

    # Missing tenant / missing document guards.
    monkeypatch.setattr(module.DocumentService, "get_tenant_id", lambda _doc_id: "")
    _set_request_json(monkeypatch, module, {"doc_id": "doc-1", "chunk_id": "chunk-1", "content_with_weight": "abc"})
    res = _run(module.set())
    assert res["message"] == "Tenant not found!", res

    monkeypatch.setattr(module.DocumentService, "get_tenant_id", lambda _doc_id: "tenant-1")
    monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (False, None))
    _set_request_json(monkeypatch, module, {"doc_id": "doc-1", "chunk_id": "chunk-1", "content_with_weight": "abc"})
    res = _run(module.set())
    assert res["message"] == "Document not found!", res

    # bytes content is decoded to str before being written to the doc store.
    monkeypatch.setattr(
        module.DocumentService,
        "get_by_id",
        lambda _doc_id: (True, _DummyDoc(doc_id="doc-1", parser_id=module.ParserType.NAIVE)),
    )
    _set_request_json(
        monkeypatch,
        module,
        {
            "doc_id": "doc-1",
            "chunk_id": "chunk-1",
            "content_with_weight": b"bytes-content",
            "important_kwd": ["important"],
            "question_kwd": ["question"],
            "tag_kwd": ["tag"],
            "tag_feas": [0.1],
            "available_int": 0,
        },
    )
    res = _run(module.set())
    assert res["code"] == 0, res
    assert module.settings.docStoreConn.updated[-1][1]["content_with_weight"] == "bytes-content"

    # QA parser + inline base64 image: the storage branch must run.
    monkeypatch.setattr(
        module.DocumentService,
        "get_by_id",
        lambda _doc_id: (True, _DummyDoc(doc_id="doc-1", parser_id=module.ParserType.QA)),
    )
    _set_request_json(
        monkeypatch,
        module,
        {
            "doc_id": "doc-1",
            "chunk_id": "chunk-2",
            "content_with_weight": "Q:Question\nA:Answer",
            "image_base64": base64.b64encode(b"image").decode("utf-8"),
            "img_id": "bucket-name",
        },
    )
    res = _run(module.set())
    assert res["code"] == 0, res
    assert module.settings.STORAGE_IMPL.put_calls, "image storage branch should be called"

    # Failures inside thread_pool_exec surface via server_error_response.
    async def _raise_thread_pool(_func):
        raise RuntimeError("set tp boom")

    monkeypatch.setattr(module, "thread_pool_exec", _raise_thread_pool)
    _set_request_json(monkeypatch, module, {"doc_id": "doc-1", "chunk_id": "chunk-1", "content_with_weight": "abc"})
    res = _run(module.set())
    assert res["code"] == module.RetCode.EXCEPTION_ERROR, res
    assert "set tp boom" in res["message"], res
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_switch_chunk_success_failure_and_exception_unit(monkeypatch):
    """Exercise switch(): missing document, index-update failure, success, and exceptions."""
    module = _load_chunk_module(monkeypatch)

    # Unknown document -> data error.
    monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (False, None))
    _set_request_json(monkeypatch, module, {"doc_id": "doc-1", "chunk_ids": ["c1"], "available_int": 1})
    res = _run(module.switch())
    assert res["message"] == "Document not found!", res

    # Doc store update returning False -> "Index updating failure".
    monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (True, _DummyDoc()))
    monkeypatch.setattr(module.DocumentService, "get_tenant_id", lambda _doc_id: "tenant-1")
    monkeypatch.setattr(module.settings.docStoreConn, "update", lambda *_args, **_kwargs: False)
    _set_request_json(monkeypatch, module, {"doc_id": "doc-1", "chunk_ids": ["c1", "c2"], "available_int": 0})
    res = _run(module.switch())
    assert res["message"] == "Index updating failure", res

    # Successful toggle returns data=True.
    monkeypatch.setattr(module.settings.docStoreConn, "update", lambda *_args, **_kwargs: True)
    _set_request_json(monkeypatch, module, {"doc_id": "doc-1", "chunk_ids": ["c1", "c2"], "available_int": 1})
    res = _run(module.switch())
    assert res["code"] == 0, res
    assert res["data"] is True, res

    # Failures inside thread_pool_exec surface via server_error_response.
    async def _raise_thread_pool(_func):
        raise RuntimeError("switch tp boom")

    monkeypatch.setattr(module, "thread_pool_exec", _raise_thread_pool)
    _set_request_json(monkeypatch, module, {"doc_id": "doc-1", "chunk_ids": ["c1"], "available_int": 1})
    res = _run(module.switch())
    assert res["code"] == module.RetCode.EXCEPTION_ERROR, res
    assert "switch tp boom" in res["message"], res
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_rm_chunk_delete_exception_partial_compensation_and_cleanup_unit(monkeypatch):
    """Exercise rm(): guards, delete failures, partial deletes, scalar ids, and exceptions."""
    module = _load_chunk_module(monkeypatch)

    # Unknown document -> data error.
    monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (False, None))
    _set_request_json(monkeypatch, module, {"doc_id": "doc-1", "chunk_ids": ["c1"]})
    res = _run(module.rm())
    assert res["message"] == "Document not found!", res

    monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (True, _DummyDoc()))

    # Store delete raising -> "Chunk deleting failure".
    def _raise_delete(*_args, **_kwargs):
        raise RuntimeError("delete boom")

    monkeypatch.setattr(module.settings.docStoreConn, "delete", _raise_delete)
    _set_request_json(monkeypatch, module, {"doc_id": "doc-1", "chunk_ids": ["c1"]})
    res = _run(module.rm())
    assert res["message"] == "Chunk deleting failure", res

    # Scripted delete() that mirrors the dummy store but stays monkeypatch-able.
    def _delete(condition, *_args, **_kwargs):
        module.settings.docStoreConn.deleted_inputs.append(condition)
        if not module.settings.docStoreConn.to_delete:
            return 0
        return module.settings.docStoreConn.to_delete.pop(0)

    # A deletion count of 0 is treated as an index-update failure.
    module.settings.docStoreConn.to_delete = [0]
    monkeypatch.setattr(module.settings.docStoreConn, "delete", _delete)
    _set_request_json(monkeypatch, module, {"doc_id": "doc-1", "chunk_ids": ["c1"]})
    res = _run(module.rm())
    assert res["message"] == "Index updating failure", res

    # Partial success: chunk counters are decremented and stored images removed.
    module.settings.docStoreConn.to_delete = [1, 2]
    _set_request_json(monkeypatch, module, {"doc_id": "doc-1", "chunk_ids": ["c1", "c2", "c3"]})
    res = _run(module.rm())
    assert res["code"] == 0, res
    assert module.DocumentService.decrement_calls, "decrement_chunk_num should be called"
    assert len(module.settings.STORAGE_IMPL.rm_calls) >= 1

    # A scalar (non-list) chunk_ids payload is accepted as well.
    module.settings.docStoreConn.to_delete = [1]
    _set_request_json(monkeypatch, module, {"doc_id": "doc-1", "chunk_ids": "c1"})
    res = _run(module.rm())
    assert res["code"] == 0, res

    # Failures inside thread_pool_exec surface via server_error_response.
    async def _raise_thread_pool(_func):
        raise RuntimeError("rm tp boom")

    monkeypatch.setattr(module, "thread_pool_exec", _raise_thread_pool)
    _set_request_json(monkeypatch, module, {"doc_id": "doc-1", "chunk_ids": ["c1"]})
    res = _run(module.rm())
    assert res["code"] == module.RetCode.EXCEPTION_ERROR, res
    assert "rm tp boom" in res["message"], res
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_create_chunk_guards_pagerank_and_success_unit(monkeypatch):
    """Exercise create(): list-type guards, lookup failures, and the pagerank-tagged insert."""
    module = _load_chunk_module(monkeypatch)
    module.request = SimpleNamespace(headers={"X-Request-ID": "req-1"}, args={})

    # Keyword fields must be lists.
    _set_request_json(monkeypatch, module, {"doc_id": "doc-1", "content_with_weight": "chunk", "important_kwd": "bad"})
    res = _run(module.create())
    assert res["message"] == "`important_kwd` is required to be a list", res

    _set_request_json(monkeypatch, module, {"doc_id": "doc-1", "content_with_weight": "chunk", "question_kwd": "bad"})
    res = _run(module.create())
    assert res["message"] == "`question_kwd` is required to be a list", res

    # Missing document / tenant / knowledgebase guards, in route order.
    monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (False, None))
    _set_request_json(monkeypatch, module, {"doc_id": "doc-1", "content_with_weight": "chunk"})
    res = _run(module.create())
    assert res["message"] == "Document not found!", res

    monkeypatch.setattr(module.DocumentService, "get_by_id", lambda _doc_id: (True, _DummyDoc(doc_id="doc-1")))
    monkeypatch.setattr(module.DocumentService, "get_tenant_id", lambda _doc_id: "")
    _set_request_json(monkeypatch, module, {"doc_id": "doc-1", "content_with_weight": "chunk"})
    res = _run(module.create())
    assert res["message"] == "Tenant not found!", res

    monkeypatch.setattr(module.DocumentService, "get_tenant_id", lambda _doc_id: "tenant-1")
    monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (False, None))
    _set_request_json(monkeypatch, module, {"doc_id": "doc-1", "content_with_weight": "chunk"})
    res = _run(module.create())
    assert res["message"] == "Knowledgebase not found!", res

    # Success: the inserted doc carries the kb's pagerank field and counters move.
    monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, SimpleNamespace(pagerank=0.8)))
    _set_request_json(
        monkeypatch,
        module,
        {
            "doc_id": "doc-1",
            "content_with_weight": "chunk",
            "important_kwd": ["i1"],
            "question_kwd": ["q1"],
            "tag_feas": [0.2],
        },
    )
    res = _run(module.create())
    assert res["code"] == 0, res
    assert module.settings.docStoreConn.inserted, "insert should be called"
    inserted = module.settings.docStoreConn.inserted[-1]
    assert "pagerank_flt" in inserted
    assert module.DocumentService.increment_calls, "increment_chunk_num should be called"
|
||||
@ -148,6 +148,35 @@ class TestAddChunk:
|
||||
else:
|
||||
assert res["message"] == expected_message, res
|
||||
|
||||
    @pytest.mark.p2
    def test_get_chunk_not_found(self, WebApiAuth):
        """Fetching a non-existent chunk id fails with a 'Chunk not found' message."""
        res = get_chunk(WebApiAuth, {"chunk_id": "missing_chunk_id"})
        assert res["code"] != 0, res
        assert "Chunk not found" in res["message"], res
|
||||
|
||||
    @pytest.mark.p2
    def test_create_chunk_with_tag_fields(self, WebApiAuth, add_document):
        """Adding a chunk with tag/keyword fields succeeds and bumps the doc's chunk_num."""
        _, doc_id = add_document
        # Baseline chunk count; a fresh document may not list successfully yet.
        res = list_chunks(WebApiAuth, {"doc_id": doc_id})
        if res["code"] == 0:
            chunks_count = res["data"]["doc"]["chunk_num"]
        else:
            chunks_count = 0

        payload = {
            "doc_id": doc_id,
            "content_with_weight": "chunk with tags",
            "tag_feas": [0.1, 0.2],
            "important_kwd": ["tag"],
            "question_kwd": ["question"],
        }
        res = add_chunk(WebApiAuth, payload)
        assert res["code"] == 0, res
        assert res["data"]["chunk_id"], res
        # The document's chunk counter must grow by exactly one.
        res = list_chunks(WebApiAuth, {"doc_id": doc_id})
        assert res["code"] == 0, res
        assert res["data"]["doc"]["chunk_num"] == chunks_count + 1, res
|
||||
|
||||
@pytest.mark.p3
|
||||
@pytest.mark.parametrize(
|
||||
"doc_id, expected_code, expected_message",
|
||||
|
||||
@ -17,7 +17,7 @@ import os
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
import pytest
|
||||
from common import batch_add_chunks, list_chunks
|
||||
from common import batch_add_chunks, list_chunks, update_chunk
|
||||
from configs import INVALID_API_TOKEN
|
||||
from libs.auth import RAGFlowWebApiAuth
|
||||
|
||||
@ -88,6 +88,33 @@ class TestChunksList:
|
||||
else:
|
||||
assert res["message"] == expected_message, res
|
||||
|
||||
    @pytest.mark.p2
    def test_available_int_filter(self, WebApiAuth, add_chunks):
        """Listing with available_int=0 returns only chunks that were disabled."""
        _, doc_id, chunk_ids = add_chunks
        chunk_id = chunk_ids[0]

        # Disable one chunk from the shared fixture.
        res = update_chunk(
            WebApiAuth,
            {"doc_id": doc_id, "chunk_id": chunk_id, "content_with_weight": "unchanged content", "available_int": 0},
        )
        assert res["code"] == 0, res

        from time import sleep

        # Give the search index a moment to absorb the update before filtering.
        sleep(1)
        res = list_chunks(WebApiAuth, {"doc_id": doc_id, "available_int": 0})
        assert res["code"] == 0, res
        assert len(res["data"]["chunks"]) >= 1, res
        assert all(chunk["available_int"] == 0 for chunk in res["data"]["chunks"]), res

        # Restore the class-scoped fixture state for subsequent keyword cases.
        res = update_chunk(
            WebApiAuth,
            {"doc_id": doc_id, "chunk_id": chunk_id, "content_with_weight": "chunk test 0", "available_int": 1},
        )
        assert res["code"] == 0, res
        sleep(1)
|
||||
|
||||
@pytest.mark.p2
|
||||
@pytest.mark.parametrize(
|
||||
"params, expected_page_size",
|
||||
|
||||
@ -95,6 +95,30 @@ class TestChunksDeletion:
|
||||
assert len(res["data"]["chunks"]) == 0, res
|
||||
assert res["data"]["total"] == 0, res
|
||||
|
||||
    @pytest.mark.p2
    def test_delete_scalar_chunk_id_payload(self, WebApiAuth, add_chunks_func):
        """A scalar (non-list) chunk_ids value is accepted and deletes exactly one chunk."""
        _, doc_id, chunk_ids = add_chunks_func
        payload = {"chunk_ids": chunk_ids[0], "doc_id": doc_id}
        res = delete_chunks(WebApiAuth, payload)
        assert res["code"] == 0, res

        # Three chunks should remain after removing one (fixture presumably seeds four).
        res = list_chunks(WebApiAuth, {"doc_id": doc_id})
        assert res["code"] == 0, res
        assert len(res["data"]["chunks"]) == 3, res
        assert res["data"]["total"] == 3, res
|
||||
|
||||
    @pytest.mark.p2
    def test_delete_duplicate_ids_dedup_behavior(self, WebApiAuth, add_chunks_func):
        """Passing the same chunk id twice deletes it only once (duplicates deduped)."""
        _, doc_id, chunk_ids = add_chunks_func
        payload = {"chunk_ids": [chunk_ids[0], chunk_ids[0]], "doc_id": doc_id}
        res = delete_chunks(WebApiAuth, payload)
        assert res["code"] == 0, res

        # Same remaining count as a single deletion — the duplicate id had no extra effect.
        res = list_chunks(WebApiAuth, {"doc_id": doc_id})
        assert res["code"] == 0, res
        assert len(res["data"]["chunks"]) == 3, res
        assert res["data"]["total"] == 3, res
|
||||
|
||||
@pytest.mark.p3
|
||||
def test_concurrent_deletion(self, WebApiAuth, add_document):
|
||||
count = 100
|
||||
|
||||
@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import base64
|
||||
import os
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from random import randint
|
||||
@ -154,6 +155,32 @@ class TestUpdateChunk:
|
||||
if chunk["chunk_id"] == chunk_id:
|
||||
assert chunk["available_int"] == payload["available_int"]
|
||||
|
||||
    @pytest.mark.p2
    def test_update_chunk_qa_multiline_content(self, WebApiAuth, add_chunks):
        """Multiline (question/answer style) content round-trips through update_chunk."""
        _, doc_id, chunk_ids = add_chunks
        payload = {"doc_id": doc_id, "chunk_id": chunk_ids[0], "content_with_weight": "Question line\nAnswer line"}
        res = update_chunk(WebApiAuth, payload)
        assert res["code"] == 0, res

        # Allow the index to catch up, then read the chunk back verbatim.
        sleep(1)
        res = list_chunks(WebApiAuth, {"doc_id": doc_id})
        assert res["code"] == 0, res
        chunk = next(chunk for chunk in res["data"]["chunks"] if chunk["chunk_id"] == chunk_ids[0])
        assert chunk["content_with_weight"] == payload["content_with_weight"], res
|
||||
|
||||
    @pytest.mark.p2
    def test_update_chunk_with_image_payload(self, WebApiAuth, add_chunks):
        """update_chunk accepts an inline base64 image together with a target img_id."""
        _, doc_id, chunk_ids = add_chunks
        payload = {
            "doc_id": doc_id,
            "chunk_id": chunk_ids[0],
            "content_with_weight": "content with image",
            "image_base64": base64.b64encode(b"img").decode("utf-8"),
            "img_id": "bucket-name",
        }
        res = update_chunk(WebApiAuth, payload)
        assert res["code"] == 0, res
|
||||
|
||||
@pytest.mark.p3
|
||||
@pytest.mark.parametrize(
|
||||
"doc_id_param, expected_code, expected_message",
|
||||
|
||||
@ -0,0 +1,219 @@
|
||||
#
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import importlib.util
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from types import ModuleType, SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class _DummyManager:
    """Pretend blueprint manager whose route() decorator leaves handlers untouched."""

    def route(self, *_args, **_kwargs):
        def _passthrough(func):
            # Route registration is irrelevant in unit tests.
            return func

        return _passthrough
|
||||
|
||||
|
||||
class _DummyAtomic:
    """No-op context manager standing in for the DB's atomic() transaction."""

    def __enter__(self):
        return self

    def __exit__(self, _exc_type, _exc, _tb):
        # Never suppress exceptions raised inside the "transaction".
        suppress = False
        return suppress
|
||||
|
||||
|
||||
class _FakeApiError(Exception):
    """Substitute for ``langfuse.api.core.api_error.ApiError`` in unit tests."""

    pass
|
||||
|
||||
|
||||
class _FakeLangfuseClient:
    """Configurable double for ``langfuse.Langfuse``.

    ``auth_check()`` raises *auth_exc* when given, otherwise returns
    *auth_result*; ``api`` mimics the projects/error surface the app reads.
    """

    def __init__(self, *, auth_result=True, auth_exc=None, project_payload=None):
        self._auth_result = auth_result
        self._auth_exc = auth_exc
        payload = project_payload
        if payload is None:
            # Default single-project listing shaped like the real API response.
            payload = {"data": [{"id": "project-id", "name": "project-name"}]}
        self.api = SimpleNamespace(
            projects=SimpleNamespace(get=lambda: SimpleNamespace(dict=lambda: payload)),
            core=SimpleNamespace(api_error=SimpleNamespace(ApiError=_FakeApiError)),
        )

    def auth_check(self):
        """Raise the configured exception, or report the configured auth result."""
        if self._auth_exc is None:
            return self._auth_result
        raise self._auth_exc
|
||||
|
||||
|
||||
def _run(coro):
|
||||
return asyncio.run(coro)
|
||||
|
||||
|
||||
def _load_langfuse_app(monkeypatch):
    """Load ``api/apps/langfuse_app.py`` in isolation with stubbed dependencies.

    Installs stub modules into ``sys.modules`` (via monkeypatch, so they are
    undone after the test) for ``common``, ``api.apps`` and ``langfuse``,
    then executes the route module from its file path with a no-op route
    manager attached. Returns the loaded module object.
    """
    # The test file lives four directories below the repository root.
    repo_root = Path(__file__).resolve().parents[4]

    # Real "common" package, resolved from the repo checkout.
    common_pkg = ModuleType("common")
    common_pkg.__path__ = [str(repo_root / "common")]
    monkeypatch.setitem(sys.modules, "common", common_pkg)

    # Stub the app package: fixed current_user, login_required as identity.
    stub_apps = ModuleType("api.apps")
    stub_apps.current_user = SimpleNamespace(id="tenant-1")
    stub_apps.login_required = lambda func: func
    monkeypatch.setitem(sys.modules, "api.apps", stub_apps)

    # Replace the Langfuse SDK with the configurable fake client.
    stub_langfuse = ModuleType("langfuse")
    stub_langfuse.Langfuse = _FakeLangfuseClient
    monkeypatch.setitem(sys.modules, "langfuse", stub_langfuse)

    # Execute the route module from source; "manager" must exist before exec
    # because route decorators run at import time.
    module_path = repo_root / "api" / "apps" / "langfuse_app.py"
    spec = importlib.util.spec_from_file_location("test_langfuse_app_unit", module_path)
    module = importlib.util.module_from_spec(spec)
    module.manager = _DummyManager()
    spec.loader.exec_module(module)
    return module
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_set_api_key_missing_fields_and_invalid_auth(monkeypatch):
    """set_api_key returns 102 for missing fields and for keys that fail auth_check."""
    module = _load_langfuse_app(monkeypatch)
    monkeypatch.setattr(module.DB, "atomic", lambda: _DummyAtomic())

    # Empty secret_key -> validation failure before any Langfuse call.
    async def missing_fields():
        return {"secret_key": "", "public_key": "pub", "host": "http://host"}

    monkeypatch.setattr(module, "get_request_json", missing_fields)
    res = _run(module.set_api_key.__wrapped__())
    assert res["code"] == 102
    assert res["message"] == "Missing required fields"

    # Complete payload, but the fake client reports the keys as invalid.
    async def invalid_auth():
        return {"secret_key": "sec", "public_key": "pub", "host": "http://host"}

    monkeypatch.setattr(module, "get_request_json", invalid_auth)
    monkeypatch.setattr(module, "Langfuse", lambda **_kwargs: _FakeLangfuseClient(auth_result=False))
    res = _run(module.set_api_key.__wrapped__())
    assert res["code"] == 102
    assert res["message"] == "Invalid Langfuse keys"
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_set_api_key_create_update_and_atomic_exception(monkeypatch):
    """set_api_key saves a new record, updates an existing one, and maps save errors to code 100."""
    module = _load_langfuse_app(monkeypatch)
    monkeypatch.setattr(module.DB, "atomic", lambda: _DummyAtomic())
    monkeypatch.setattr(module, "Langfuse", lambda **_kwargs: _FakeLangfuseClient(auth_result=True))

    async def payload():
        return {"secret_key": "sec", "public_key": "pub", "host": "http://host"}

    monkeypatch.setattr(module, "get_request_json", payload)

    # Count which persistence path the route takes.
    calls = {"save": 0, "update": 0}
    monkeypatch.setattr(module.TenantLangfuseService, "filter_by_tenant", lambda **_kwargs: None)
    monkeypatch.setattr(
        module.TenantLangfuseService,
        "save",
        lambda **_kwargs: calls.__setitem__("save", calls["save"] + 1),
    )
    monkeypatch.setattr(
        module.TenantLangfuseService,
        "update_by_tenant",
        lambda **_kwargs: calls.__setitem__("update", calls["update"] + 1),
    )
    # No existing record -> save path.
    res = _run(module.set_api_key.__wrapped__())
    assert res["code"] == 0
    assert calls["save"] == 1

    # Existing record -> update path.
    monkeypatch.setattr(module.TenantLangfuseService, "filter_by_tenant", lambda **_kwargs: {"id": "existing"})
    res = _run(module.set_api_key.__wrapped__())
    assert res["code"] == 0
    assert calls["update"] == 1

    # save() raising inside the atomic block surfaces as code 100.
    monkeypatch.setattr(module.TenantLangfuseService, "filter_by_tenant", lambda **_kwargs: None)

    def raise_save(**_kwargs):
        raise RuntimeError("save failed")

    monkeypatch.setattr(module.TenantLangfuseService, "save", raise_save)
    res = _run(module.set_api_key.__wrapped__())
    assert res["code"] == 100
    assert "save failed" in res["message"]
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_get_api_key_no_record_invalid_auth_api_error_generic_error_success(monkeypatch):
    """get_api_key covers: no stored keys, invalid keys, SDK ApiError, generic error, success."""
    module = _load_langfuse_app(monkeypatch)

    # No stored keys -> informative success response.
    monkeypatch.setattr(module.TenantLangfuseService, "filter_by_tenant_with_info", lambda **_kwargs: None)
    res = module.get_api_key.__wrapped__()
    assert res["code"] == 0
    assert res["message"] == "Have not record any Langfuse keys."

    # Stored keys exist but auth_check fails -> code 102.
    base_entry = {"secret_key": "sec", "public_key": "pub", "host": "http://host"}
    monkeypatch.setattr(module.TenantLangfuseService, "filter_by_tenant_with_info", lambda **_kwargs: dict(base_entry))
    monkeypatch.setattr(module, "Langfuse", lambda **_kwargs: _FakeLangfuseClient(auth_result=False))
    res = module.get_api_key.__wrapped__()
    assert res["code"] == 102
    assert res["message"] == "Invalid Langfuse keys loaded"

    # The SDK's own ApiError is reported as a Langfuse-side error (code 0 + message).
    monkeypatch.setattr(
        module,
        "Langfuse",
        lambda **_kwargs: _FakeLangfuseClient(auth_exc=_FakeApiError("api exploded")),
    )
    res = module.get_api_key.__wrapped__()
    assert res["code"] == 0
    assert "Error from Langfuse" in res["message"]

    # Any other exception is a server error (code 100).
    monkeypatch.setattr(
        module,
        "Langfuse",
        lambda **_kwargs: _FakeLangfuseClient(auth_exc=RuntimeError("generic exploded")),
    )
    res = module.get_api_key.__wrapped__()
    assert res["code"] == 100
    assert "generic exploded" in res["message"]

    # Happy path returns the fake project's id and name.
    monkeypatch.setattr(module, "Langfuse", lambda **_kwargs: _FakeLangfuseClient(auth_result=True))
    res = module.get_api_key.__wrapped__()
    assert res["code"] == 0
    assert res["data"]["project_id"] == "project-id"
    assert res["data"]["project_name"] == "project-name"
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_delete_api_key_no_record_success_exception(monkeypatch):
    """delete_api_key: no record is a soft success, delete works, delete errors map to code 100."""
    module = _load_langfuse_app(monkeypatch)
    monkeypatch.setattr(module.DB, "atomic", lambda: _DummyAtomic())

    # Nothing stored -> nothing to delete, still code 0.
    monkeypatch.setattr(module.TenantLangfuseService, "filter_by_tenant", lambda **_kwargs: None)
    res = module.delete_api_key.__wrapped__()
    assert res["code"] == 0
    assert res["message"] == "Have not record any Langfuse keys."

    # Stored entry deleted successfully.
    monkeypatch.setattr(module.TenantLangfuseService, "filter_by_tenant", lambda **_kwargs: {"id": "entry"})
    monkeypatch.setattr(module.TenantLangfuseService, "delete_model", lambda _entry: None)
    res = module.delete_api_key.__wrapped__()
    assert res["code"] == 0
    assert res["data"] is True

    # delete_model raising surfaces as code 100 with the error text.
    def raise_delete(_entry):
        raise RuntimeError("delete failed")

    monkeypatch.setattr(module.TenantLangfuseService, "delete_model", raise_delete)
    res = module.delete_api_key.__wrapped__()
    assert res["code"] == 100
    assert "delete failed" in res["message"]
|
||||
@ -0,0 +1,584 @@
|
||||
#
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import importlib.util
|
||||
import sys
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from types import ModuleType, SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from anyio import Path as AsyncPath
|
||||
|
||||
|
||||
class _DummyManager:
|
||||
def route(self, *_args, **_kwargs):
|
||||
def decorator(func):
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class _AwaitableValue:
|
||||
def __init__(self, value):
|
||||
self._value = value
|
||||
|
||||
def __await__(self):
|
||||
async def _co():
|
||||
return self._value
|
||||
|
||||
return _co().__await__()
|
||||
|
||||
|
||||
class _DummyRequest:
    """Quart ``request`` double.

    ``args``/``headers`` are plain dicts; ``form``/``files`` must be awaited,
    exactly like the real request object the routes use.
    """

    def __init__(self, *, args=None, headers=None, form=None, files=None):
        self.args = args or {}
        self.headers = headers or {}
        self.form = _AwaitableValue(form or {})
        self.files = _AwaitableValue(files or {})
        # Fixed method/length: the routes only read, never branch on, these.
        self.method = "POST"
        self.content_length = 0
|
||||
|
||||
|
||||
class _DummyConversation:
|
||||
def __init__(self, *, conv_id="conv-1", dialog_id="dialog-1", message=None, reference=None):
|
||||
self.id = conv_id
|
||||
self.dialog_id = dialog_id
|
||||
self.message = message if message is not None else []
|
||||
self.reference = reference if reference is not None else []
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"id": self.id,
|
||||
"dialog_id": self.dialog_id,
|
||||
"message": deepcopy(self.message),
|
||||
"reference": deepcopy(self.reference),
|
||||
}
|
||||
|
||||
|
||||
class _DummyDialog:
|
||||
def __init__(self, *, dialog_id="dialog-1", tenant_id="tenant-1", icon="avatar.png"):
|
||||
self.id = dialog_id
|
||||
self.tenant_id = tenant_id
|
||||
self.icon = icon
|
||||
self.prompt_config = {"prologue": "hello"}
|
||||
self.llm_id = ""
|
||||
self.llm_setting = {}
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"id": self.id,
|
||||
"icon": self.icon,
|
||||
"tenant_id": self.tenant_id,
|
||||
"prompt_config": deepcopy(self.prompt_config),
|
||||
}
|
||||
|
||||
|
||||
class _DummyUploadedFile:
    """Async upload double mimicking quart's FileStorage ``save`` interface."""

    def __init__(self, filename):
        self.filename = filename
        # Records the last path passed to save() so tests can inspect it.
        self.saved_path = None

    async def save(self, path):
        # Write a fixed payload so downstream code has a real file to read.
        self.saved_path = path
        await AsyncPath(path).write_bytes(b"audio-bytes")
|
||||
|
||||
|
||||
def _run(coro):
|
||||
return asyncio.run(coro)
|
||||
|
||||
|
||||
def _load_conversation_module(monkeypatch):
    """Load ``api/apps/conversation_app.py`` in isolation with stubbed heavy deps.

    Installs stub modules (via monkeypatch, auto-undone) for ``common``, the
    ``deepdoc`` parser tree, ``xgboost`` and ``api.apps`` so the route module
    imports without the real ML/parsing stack, then executes it from its file
    path with a no-op route manager. Returns the loaded module.
    """
    # The test file lives four directories below the repository root.
    repo_root = Path(__file__).resolve().parents[4]

    # Real "common" package, resolved from the repo checkout.
    common_pkg = ModuleType("common")
    common_pkg.__path__ = [str(repo_root / "common")]
    monkeypatch.setitem(sys.modules, "common", common_pkg)

    # Stub the deepdoc parser package hierarchy with empty parser classes.
    deepdoc_pkg = ModuleType("deepdoc")
    deepdoc_parser_pkg = ModuleType("deepdoc.parser")
    deepdoc_parser_pkg.__path__ = []

    class _StubPdfParser:
        pass

    class _StubExcelParser:
        pass

    class _StubDocxParser:
        pass

    deepdoc_parser_pkg.PdfParser = _StubPdfParser
    deepdoc_parser_pkg.ExcelParser = _StubExcelParser
    deepdoc_parser_pkg.DocxParser = _StubDocxParser
    deepdoc_pkg.parser = deepdoc_parser_pkg
    monkeypatch.setitem(sys.modules, "deepdoc", deepdoc_pkg)
    monkeypatch.setitem(sys.modules, "deepdoc.parser", deepdoc_parser_pkg)

    deepdoc_excel_module = ModuleType("deepdoc.parser.excel_parser")
    deepdoc_excel_module.RAGFlowExcelParser = _StubExcelParser
    monkeypatch.setitem(sys.modules, "deepdoc.parser.excel_parser", deepdoc_excel_module)

    deepdoc_parser_utils = ModuleType("deepdoc.parser.utils")
    deepdoc_parser_utils.get_text = lambda *_args, **_kwargs: ""
    monkeypatch.setitem(sys.modules, "deepdoc.parser.utils", deepdoc_parser_utils)
    # xgboost is only imported transitively; an empty module is enough.
    monkeypatch.setitem(sys.modules, "xgboost", ModuleType("xgboost"))

    # Stub the app package: fixed current_user, login_required as identity.
    apps_mod = ModuleType("api.apps")
    apps_mod.current_user = SimpleNamespace(id="user-1")
    apps_mod.login_required = lambda func: func
    monkeypatch.setitem(sys.modules, "api.apps", apps_mod)

    # Execute the route module from source; "manager" must exist before exec
    # because route decorators run at import time. The module is registered in
    # sys.modules (monkeypatched) so intra-module lookups resolve.
    module_name = "test_conversation_routes_unit_module"
    module_path = repo_root / "api" / "apps" / "conversation_app.py"
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    module = importlib.util.module_from_spec(spec)
    module.manager = _DummyManager()
    monkeypatch.setitem(sys.modules, module_name, module)
    spec.loader.exec_module(module)
    return module
|
||||
|
||||
|
||||
def _set_request_json(monkeypatch, module, payload):
    """Patch ``get_request_json`` so awaiting it yields a deep copy of *payload*."""

    def fake_get_request_json():
        # Deep-copy per call: the route may mutate the payload it receives.
        return _AwaitableValue(deepcopy(payload))

    monkeypatch.setattr(module, "get_request_json", fake_get_request_json)
|
||||
|
||||
|
||||
async def _read_sse_text(response):
|
||||
chunks = []
|
||||
async for chunk in response.response:
|
||||
if isinstance(chunk, bytes):
|
||||
chunks.append(chunk.decode("utf-8"))
|
||||
else:
|
||||
chunks.append(chunk)
|
||||
return "".join(chunks)
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_set_conversation_update_create_and_errors(monkeypatch):
    """set_conversation: create (name truncation), update, and every error branch."""
    module = _load_conversation_module(monkeypatch)

    # Create path: a 300-char name is truncated to 255; user_id is recorded.
    long_name = "n" * 300
    create_payload = {
        "conversation_id": "conv-new",
        "dialog_id": "dialog-1",
        "is_new": True,
        "name": long_name,
    }
    _set_request_json(monkeypatch, module, create_payload)

    saved = {}
    monkeypatch.setattr(module.DialogService, "get_by_id", lambda _id: (True, _DummyDialog()))
    monkeypatch.setattr(module.ConversationService, "save", lambda **kwargs: saved.update(kwargs) or True)
    res = _run(module.set_conversation())
    assert res["code"] == 0
    assert len(res["data"]["name"]) == 255
    assert saved["user_id"] == "user-1"

    # Update path: update_by_id returning False -> "Conversation not found".
    update_payload = {
        "conversation_id": "conv-1",
        "dialog_id": "dialog-1",
        "is_new": False,
        "name": "rename",
    }
    _set_request_json(monkeypatch, module, update_payload)
    monkeypatch.setattr(module.ConversationService, "update_by_id", lambda *_args, **_kwargs: False)
    res = _run(module.set_conversation())
    assert "Conversation not found" in res["message"]

    # Update succeeded but re-fetch failed -> "Fail to update".
    _set_request_json(monkeypatch, module, update_payload)
    monkeypatch.setattr(module.ConversationService, "update_by_id", lambda *_args, **_kwargs: True)
    monkeypatch.setattr(module.ConversationService, "get_by_id", lambda _id: (False, None))
    res = _run(module.set_conversation())
    assert "Fail to update" in res["message"]

    # Successful update returns the refreshed record.
    _set_request_json(monkeypatch, module, update_payload)
    monkeypatch.setattr(module.ConversationService, "update_by_id", lambda *_args, **_kwargs: True)
    monkeypatch.setattr(module.ConversationService, "get_by_id", lambda _id: (True, _DummyConversation(conv_id="conv-1")))
    res = _run(module.set_conversation())
    assert res["code"] == 0
    assert res["data"]["id"] == "conv-1"

    # Exception during update -> EXCEPTION_ERROR with the error text.
    _set_request_json(monkeypatch, module, update_payload)

    def _raise_update(*_args, **_kwargs):
        raise RuntimeError("update boom")

    monkeypatch.setattr(module.ConversationService, "update_by_id", _raise_update)
    res = _run(module.set_conversation())
    assert res["code"] == module.RetCode.EXCEPTION_ERROR
    assert "update boom" in res["message"]

    # Create path with a missing dialog.
    missing_dialog_payload = {
        "conversation_id": "conv-2",
        "dialog_id": "dialog-missing",
        "is_new": True,
        "name": "create",
    }
    _set_request_json(monkeypatch, module, missing_dialog_payload)
    monkeypatch.setattr(module.DialogService, "get_by_id", lambda _id: (False, None))
    res = _run(module.set_conversation())
    assert res["message"] == "Dialog not found"

    # Exception while fetching the dialog -> EXCEPTION_ERROR.
    _set_request_json(monkeypatch, module, missing_dialog_payload)

    def _raise_dialog(_id):
        raise RuntimeError("dialog boom")

    monkeypatch.setattr(module.DialogService, "get_by_id", _raise_dialog)
    res = _run(module.set_conversation())
    assert res["code"] == module.RetCode.EXCEPTION_ERROR
    assert "dialog boom" in res["message"]
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_get_and_getsse_authorization_and_reference_paths(monkeypatch):
    """get(): reference normalization + ownership checks; getsse(): auth/token/dialog branches."""
    module = _load_conversation_module(monkeypatch)

    # Happy path: dict references go through chunks_format, list ones pass as-is.
    conv = _DummyConversation(reference=[{"doc": "d"}, ["already-formatted"]])
    monkeypatch.setattr(module, "request", _DummyRequest(args={"conversation_id": "conv-1"}))
    monkeypatch.setattr(module.ConversationService, "get_by_id", lambda _id: (True, conv))
    monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [SimpleNamespace(tenant_id="tenant-1")])
    monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: [SimpleNamespace(icon="bot-avatar")])
    monkeypatch.setattr(module, "chunks_format", lambda _ref: [{"chunk": "normalized"}])

    res = _run(module.get())
    assert res["code"] == 0
    assert res["data"]["avatar"] == "bot-avatar"
    assert res["data"]["reference"][0]["chunks"] == [{"chunk": "normalized"}]

    # Unknown conversation id.
    monkeypatch.setattr(module.ConversationService, "get_by_id", lambda _id: (False, None))
    res = _run(module.get())
    assert res["message"] == "Conversation not found!"

    # Conversation exists but the user owns no matching dialog.
    monkeypatch.setattr(module, "request", _DummyRequest(args={"conversation_id": "conv-1"}))
    monkeypatch.setattr(module.ConversationService, "get_by_id", lambda _id: (True, conv))
    monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: [])
    res = _run(module.get())
    assert res["code"] == module.RetCode.OPERATING_ERROR
    assert "Only owner of conversation" in res["message"]

    # Exception while fetching -> EXCEPTION_ERROR.
    def _raise_get(*_args, **_kwargs):
        raise RuntimeError("get boom")

    monkeypatch.setattr(module.ConversationService, "get_by_id", _raise_get)
    res = _run(module.get())
    assert res["code"] == module.RetCode.EXCEPTION_ERROR
    assert "get boom" in res["message"]

    # getsse: "Bearer" with no token part.
    monkeypatch.setattr(module, "request", _DummyRequest(headers={"Authorization": "Bearer"}))
    res = module.getsse("dialog-1")
    assert "Authorization is not valid" in res["message"]

    # getsse: token present but unknown.
    monkeypatch.setattr(module, "request", _DummyRequest(headers={"Authorization": "Bearer token-1"}))
    monkeypatch.setattr(module.APIToken, "query", lambda **_kwargs: [])
    res = module.getsse("dialog-1")
    assert "API key is invalid" in res["message"]

    # getsse: valid token, missing dialog.
    monkeypatch.setattr(module.APIToken, "query", lambda **_kwargs: [SimpleNamespace()])
    monkeypatch.setattr(module.DialogService, "get_by_id", lambda _id: (False, None))
    res = module.getsse("dialog-1")
    assert res["message"] == "Dialog not found!"

    # getsse success: "icon" is renamed to "avatar" in the response.
    monkeypatch.setattr(module.DialogService, "get_by_id", lambda _id: (True, _DummyDialog()))
    res = module.getsse("dialog-1")
    assert res["code"] == 0
    assert res["data"]["avatar"] == "avatar.png"
    assert "icon" not in res["data"]

    # getsse exception path.
    def _raise_getsse(_id):
        raise RuntimeError("getsse boom")

    monkeypatch.setattr(module.DialogService, "get_by_id", _raise_getsse)
    res = module.getsse("dialog-1")
    assert res["code"] == module.RetCode.EXCEPTION_ERROR
    assert "getsse boom" in res["message"]
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_rm_and_list_conversation_guards(monkeypatch):
    """rm(): missing/unauthorized/success/exception; list_conversation(): ownership, listing, exception."""
    module = _load_conversation_module(monkeypatch)

    # rm: conversation id does not exist.
    _set_request_json(monkeypatch, module, {"conversation_ids": ["conv-1"]})
    monkeypatch.setattr(module.ConversationService, "get_by_id", lambda _id: (False, None))
    res = _run(module.rm())
    assert "Conversation not found" in res["message"]

    # rm: exists but the user owns no matching dialog.
    conv = _DummyConversation(conv_id="conv-1", dialog_id="dialog-1")
    _set_request_json(monkeypatch, module, {"conversation_ids": ["conv-1"]})
    monkeypatch.setattr(module.ConversationService, "get_by_id", lambda _id: (True, conv))
    monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [SimpleNamespace(tenant_id="tenant-1")])
    monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: [])
    res = _run(module.rm())
    assert res["code"] == module.RetCode.OPERATING_ERROR

    # rm: authorized -> delete_by_id is called exactly for the requested id.
    deleted = []
    _set_request_json(monkeypatch, module, {"conversation_ids": ["conv-1"]})
    monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: [SimpleNamespace(id="dialog-1")])
    monkeypatch.setattr(module.ConversationService, "delete_by_id", lambda cid: deleted.append(cid) or True)
    res = _run(module.rm())
    assert res["code"] == 0
    assert res["data"] is True
    assert deleted == ["conv-1"]

    # rm: exception path.
    _set_request_json(monkeypatch, module, {"conversation_ids": ["conv-1"]})

    def _raise_rm(*_args, **_kwargs):
        raise RuntimeError("rm boom")

    monkeypatch.setattr(module.ConversationService, "get_by_id", _raise_rm)
    res = _run(module.rm())
    assert res["code"] == module.RetCode.EXCEPTION_ERROR
    assert "rm boom" in res["message"]

    # list: caller does not own the dialog.
    monkeypatch.setattr(module, "request", _DummyRequest(args={"dialog_id": "dialog-1"}))
    monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: [])
    res = _run(module.list_conversation())
    assert res["code"] == module.RetCode.OPERATING_ERROR
    assert "Only owner of dialog" in res["message"]

    # list success: returns conversations in query order.
    monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: [SimpleNamespace(id="dialog-1")])
    monkeypatch.setattr(module.ConversationService, "model", SimpleNamespace(create_time="create_time"))
    monkeypatch.setattr(module.ConversationService, "query", lambda **_kwargs: [_DummyConversation(conv_id="c1"), _DummyConversation(conv_id="c2")])
    res = _run(module.list_conversation())
    assert res["code"] == 0
    assert [x["id"] for x in res["data"]] == ["c1", "c2"]

    # list exception path.
    def _raise_list(**_kwargs):
        raise RuntimeError("list boom")

    monkeypatch.setattr(module.ConversationService, "query", _raise_list)
    res = _run(module.list_conversation())
    assert res["code"] == module.RetCode.EXCEPTION_ERROR
    assert "list boom" in res["message"]
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_completion_stream_and_nonstream_branches(monkeypatch):
    """completion(): SSE streaming (ok + error), non-stream, model override, and guard branches."""
    module = _load_conversation_module(monkeypatch)

    conv = _DummyConversation(conv_id="conv-1", dialog_id="dialog-1", reference=[])
    dia = _DummyDialog(dialog_id="dialog-1", tenant_id="tenant-1")
    monkeypatch.setattr(module.ConversationService, "get_by_id", lambda _id: (True, conv))
    monkeypatch.setattr(module.DialogService, "get_by_id", lambda _id: (True, dia))
    # Pass the chat answer through unchanged so assertions see it directly.
    monkeypatch.setattr(module, "structure_answer", lambda _conv, ans, message_id, conv_id: {"answer": ans["answer"], "id": message_id, "conversation_id": conv_id, "reference": []})

    # Record conversation persistence after the stream completes.
    updates = []
    monkeypatch.setattr(module.ConversationService, "update_by_id", lambda conv_id, payload: updates.append((conv_id, payload)) or True)

    stream_payload = {
        "conversation_id": "conv-1",
        "messages": [
            {"role": "system", "content": "ignored"},
            {"role": "assistant", "content": "ignored-first-assistant"},
            {"role": "user", "content": "hello", "id": "m-1"},
        ],
        "stream": True,
    }

    # Streaming happy path; also asserts system/leading-assistant messages are
    # stripped before reaching the chat backend.
    async def _stream_ok(_dia, sanitized, *_args, **_kwargs):
        assert [m["role"] for m in sanitized] == ["user"]
        yield {"answer": "sse-ok"}

    monkeypatch.setattr(module, "async_chat", _stream_ok)
    _set_request_json(monkeypatch, module, stream_payload)
    resp = _run(module.completion.__wrapped__())
    assert resp.headers["Content-Type"].startswith("text/event-stream")
    sse_text = _run(_read_sse_text(resp))
    assert "sse-ok" in sse_text
    # The stream ends with a '"data": true' terminator event.
    assert '"data": true' in sse_text
    assert updates

    # Streaming error path: the exception is surfaced inside the SSE body.
    async def _stream_error(_dia, _sanitized, *_args, **_kwargs):
        raise RuntimeError("stream explode")
        if False:
            yield {"answer": "never"}

    monkeypatch.setattr(module, "async_chat", _stream_error)
    _set_request_json(monkeypatch, module, stream_payload)
    resp = _run(module.completion.__wrapped__())
    sse_text = _run(_read_sse_text(resp))
    assert "**ERROR**: stream explode" in sse_text

    # Non-streaming happy path returns the final answer as JSON.
    async def _non_stream(_dia, _sanitized, **_kwargs):
        yield {"answer": "plain-ok"}

    monkeypatch.setattr(module, "async_chat", _non_stream)
    _set_request_json(
        monkeypatch,
        module,
        {
            "conversation_id": "conv-1",
            "messages": [{"role": "user", "content": "plain", "id": "m-2"}],
            "stream": False,
        },
    )
    res = _run(module.completion.__wrapped__())
    assert res["code"] == 0
    assert res["data"]["answer"] == "plain-ok"

    # llm_id override rejected when the tenant has no API key for it.
    monkeypatch.setattr(module.TenantLLMService, "get_api_key", lambda **_kwargs: False)
    _set_request_json(
        monkeypatch,
        module,
        {
            "conversation_id": "conv-1",
            "messages": [{"role": "user", "content": "embed", "id": "m-3"}],
            "llm_id": "bad-model",
            "stream": False,
        },
    )
    res = _run(module.completion.__wrapped__())
    assert "Cannot use specified model bad-model" in res["message"]

    # llm_id override accepted: model and sampling settings land on the dialog.
    monkeypatch.setattr(module.TenantLLMService, "get_api_key", lambda **_kwargs: "api-key")
    _set_request_json(
        monkeypatch,
        module,
        {
            "conversation_id": "conv-1",
            "messages": [{"role": "user", "content": "embed", "id": "m-4"}],
            "llm_id": "glm-4",
            "temperature": 0.7,
            "top_p": 0.2,
            "stream": False,
        },
    )
    res = _run(module.completion.__wrapped__())
    assert res["code"] == 0
    assert dia.llm_id == "glm-4"
    assert dia.llm_setting == {"temperature": 0.7, "top_p": 0.2}

    # Unknown conversation id.
    _set_request_json(
        monkeypatch,
        module,
        {
            "conversation_id": "missing",
            "messages": [{"role": "user", "content": "x", "id": "m-5"}],
            "stream": False,
        },
    )
    monkeypatch.setattr(module.ConversationService, "get_by_id", lambda _id: (False, None))
    res = _run(module.completion.__wrapped__())
    assert res["message"] == "Conversation not found!"

    # Conversation found but its dialog is missing.
    monkeypatch.setattr(module.ConversationService, "get_by_id", lambda _id: (True, conv))
    monkeypatch.setattr(module.DialogService, "get_by_id", lambda _id: (False, None))
    _set_request_json(
        monkeypatch,
        module,
        {
            "conversation_id": "conv-1",
            "messages": [{"role": "user", "content": "x", "id": "m-6"}],
            "stream": False,
        },
    )
    res = _run(module.completion.__wrapped__())
    assert res["message"] == "Dialog not found!"

    # Exception while loading the conversation -> EXCEPTION_ERROR.
    monkeypatch.setattr(module.ConversationService, "get_by_id", lambda _id: (_ for _ in ()).throw(RuntimeError("completion boom")))
    _set_request_json(
        monkeypatch,
        module,
        {
            "conversation_id": "conv-1",
            "messages": [{"role": "user", "content": "x", "id": "m-7"}],
            "stream": False,
        },
    )
    res = _run(module.completion.__wrapped__())
    assert res["code"] == module.RetCode.EXCEPTION_ERROR
    assert "completion boom" in res["message"]
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_sequence2txt_validation_and_transcription_paths(monkeypatch):
    """sequence2txt(): upload validation, tenant/ASR guards, sync and streaming transcription."""
    module = _load_conversation_module(monkeypatch)

    # No file part in the upload.
    monkeypatch.setattr(module, "request", _DummyRequest(form={"stream": "false"}, files={}))
    res = _run(module.sequence2txt())
    assert "Missing 'file'" in res["message"]

    # Wrong extension is rejected before any tenant lookup.
    bad_file = _DummyUploadedFile("audio.txt")
    monkeypatch.setattr(module, "request", _DummyRequest(form={"stream": "false"}, files={"file": bad_file}))
    res = _run(module.sequence2txt())
    assert "Unsupported audio format" in res["message"]

    # Valid audio but no tenant record.
    wav_file = _DummyUploadedFile("audio.wav")
    monkeypatch.setattr(module, "request", _DummyRequest(form={"stream": "false"}, files={"file": wav_file}))
    monkeypatch.setattr(module.TenantService, "get_info_by", lambda _uid: [])
    res = _run(module.sequence2txt())
    assert res["message"] == "Tenant not found!"

    # Tenant exists but has no default ASR model configured.
    wav_file = _DummyUploadedFile("audio.wav")
    monkeypatch.setattr(module, "request", _DummyRequest(form={"stream": "false"}, files={"file": wav_file}))
    monkeypatch.setattr(module.TenantService, "get_info_by", lambda _uid: [{"tenant_id": "tenant-1", "asr_id": ""}])
    res = _run(module.sequence2txt())
    assert res["message"] == "No default ASR model is set"

    # Synchronous transcription succeeds even if temp-file cleanup raises.
    class _SyncAsr:
        def transcription(self, _path):
            return "transcribed text"

        def stream_transcription(self, _path):
            return []

    wav_file = _DummyUploadedFile("audio.wav")
    monkeypatch.setattr(module, "request", _DummyRequest(form={"stream": "false"}, files={"file": wav_file}))
    monkeypatch.setattr(module.TenantService, "get_info_by", lambda _uid: [{"tenant_id": "tenant-1", "asr_id": "asr-model"}])
    monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: _SyncAsr())
    # os.remove failing must not break the response (best-effort cleanup).
    monkeypatch.setattr(module.os, "remove", lambda _path: (_ for _ in ()).throw(RuntimeError("remove failed")))
    res = _run(module.sequence2txt())
    assert res["code"] == 0
    assert res["data"]["text"] == "transcribed text"

    # Streaming transcription: events are relayed over SSE.
    class _StreamAsr:
        def transcription(self, _path):
            return ""

        def stream_transcription(self, _path):
            yield {"event": "partial", "text": "hello"}

    wav_file = _DummyUploadedFile("audio.wav")
    monkeypatch.setattr(module, "request", _DummyRequest(form={"stream": "true"}, files={"file": wav_file}))
    monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: _StreamAsr())
    resp = _run(module.sequence2txt())
    assert resp.headers["Content-Type"].startswith("text/event-stream")
    sse_text = _run(_read_sse_text(resp))
    assert '"event": "partial"' in sse_text

    # Streaming transcription errors are surfaced inside the SSE body.
    class _ErrorStreamAsr:
        def transcription(self, _path):
            return ""

        def stream_transcription(self, _path):
            raise RuntimeError("stream asr boom")

    wav_file = _DummyUploadedFile("audio.wav")
    monkeypatch.setattr(module, "request", _DummyRequest(form={"stream": "true"}, files={"file": wav_file}))
    monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: _ErrorStreamAsr())
    resp = _run(module.sequence2txt())
    sse_text = _run(_read_sse_text(resp))
    assert "stream asr boom" in sse_text
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_tts_request_parse_entry(monkeypatch):
    """tts(): a missing tenant record short-circuits before any TTS model is loaded."""
    module = _load_conversation_module(monkeypatch)
    _set_request_json(monkeypatch, module, {"text": "hello"})
    monkeypatch.setattr(module.TenantService, "get_info_by", lambda _uid: [])
    res = _run(module.tts())
    assert res["message"] == "Tenant not found!"
|
||||
@ -15,10 +15,22 @@
|
||||
#
|
||||
|
||||
|
||||
import importlib.util
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from types import ModuleType, SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from common import bulk_upload_documents, delete_document, list_documents
|
||||
|
||||
|
||||
class _DummyManager:
|
||||
def route(self, *_args, **_kwargs):
|
||||
def decorator(func):
|
||||
return func
|
||||
return decorator
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def add_document_func(request, WebApiAuth, add_dataset, ragflow_tmp_dir):
|
||||
def cleanup():
|
||||
@ -56,3 +68,49 @@ def add_documents_func(request, WebApiAuth, add_dataset_func, ragflow_tmp_dir):
|
||||
|
||||
dataset_id = add_dataset_func
|
||||
return dataset_id, bulk_upload_documents(WebApiAuth, dataset_id, 3, ragflow_tmp_dir)
|
||||
|
||||
|
||||
@pytest.fixture()
def document_app_module(monkeypatch):
    """Load ``api/apps/document_app.py`` in isolation with stubbed heavy deps.

    Installs stub modules (via monkeypatch, auto-undone) for ``common``, the
    ``deepdoc`` parser tree, ``xgboost`` and ``api.apps`` so the route module
    imports without the real ML/parsing stack, then executes it from its file
    path with a no-op route manager and yields the loaded module.
    """
    # The conftest lives four directories below the repository root.
    repo_root = Path(__file__).resolve().parents[4]
    # Real "common" package, resolved from the repo checkout.
    common_pkg = ModuleType("common")
    common_pkg.__path__ = [str(repo_root / "common")]
    monkeypatch.setitem(sys.modules, "common", common_pkg)

    # Stub the deepdoc parser package hierarchy with empty parser classes.
    deepdoc_pkg = ModuleType("deepdoc")
    deepdoc_parser_pkg = ModuleType("deepdoc.parser")
    deepdoc_parser_pkg.__path__ = []

    class _StubPdfParser:
        pass

    class _StubExcelParser:
        pass

    deepdoc_parser_pkg.PdfParser = _StubPdfParser
    deepdoc_pkg.parser = deepdoc_parser_pkg
    monkeypatch.setitem(sys.modules, "deepdoc", deepdoc_pkg)
    monkeypatch.setitem(sys.modules, "deepdoc.parser", deepdoc_parser_pkg)
    deepdoc_excel_module = ModuleType("deepdoc.parser.excel_parser")
    deepdoc_excel_module.RAGFlowExcelParser = _StubExcelParser
    monkeypatch.setitem(sys.modules, "deepdoc.parser.excel_parser", deepdoc_excel_module)
    deepdoc_html_module = ModuleType("deepdoc.parser.html_parser")

    class _StubHtmlParser:
        pass

    deepdoc_html_module.RAGFlowHtmlParser = _StubHtmlParser
    monkeypatch.setitem(sys.modules, "deepdoc.parser.html_parser", deepdoc_html_module)
    # xgboost is only imported transitively; an empty module is enough.
    monkeypatch.setitem(sys.modules, "xgboost", ModuleType("xgboost"))

    # Stub the app package: fixed current_user, login_required as identity.
    stub_apps = ModuleType("api.apps")
    stub_apps.current_user = SimpleNamespace(id="user-1")
    stub_apps.login_required = lambda func: func
    monkeypatch.setitem(sys.modules, "api.apps", stub_apps)

    # Execute the route module from source; "manager" must exist before exec
    # because route decorators run at import time.
    module_path = repo_root / "api" / "apps" / "document_app.py"
    spec = importlib.util.spec_from_file_location("test_document_app_unit", module_path)
    module = importlib.util.module_from_spec(spec)
    module.manager = _DummyManager()
    spec.loader.exec_module(module)
    return module
|
||||
|
||||
@ -13,7 +13,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import string
|
||||
from types import SimpleNamespace
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
import pytest
|
||||
@ -21,6 +23,7 @@ from common import create_document, list_kbs
|
||||
from configs import DOCUMENT_NAME_LIMIT, INVALID_API_TOKEN
|
||||
from libs.auth import RAGFlowWebApiAuth
|
||||
from utils.file_utils import create_txt_file
|
||||
from api.constants import FILE_NAME_LEN_LIMIT
|
||||
|
||||
|
||||
@pytest.mark.p1
|
||||
@ -90,3 +93,130 @@ class TestDocumentCreate:
|
||||
|
||||
res = list_kbs(WebApiAuth, {"id": kb_id})
|
||||
assert res["data"]["kbs"][0]["doc_num"] == count, res
|
||||
|
||||
|
||||
def _run(coro):
    """Drive *coro* to completion on a fresh event loop and return its result."""
    return asyncio.run(coro)
|
||||
|
||||
|
||||
@pytest.mark.p2
class TestDocumentCreateUnit:
    """Unit tests for document_app.create, one per validation branch."""

    @staticmethod
    def _invoke(module, monkeypatch, payload):
        # Feed *payload* as the parsed request body, then call the view with
        # its decorators stripped.
        async def fake_request_json():
            return payload

        monkeypatch.setattr(module, "get_request_json", fake_request_json)
        return _run(module.create.__wrapped__())

    @staticmethod
    def _make_kb():
        # Minimal knowledgebase record carrying the attributes create() reads.
        return SimpleNamespace(id="kb1", tenant_id="tenant1", name="kb", parser_id="parser", pipeline_id="pipe", parser_config={})

    def test_missing_kb_id(self, document_app_module, monkeypatch):
        res = self._invoke(document_app_module, monkeypatch, {"kb_id": "", "name": "doc.txt"})
        assert res["code"] == 101
        assert res["message"] == 'Lack of "KB ID"'

    def test_filename_too_long(self, document_app_module, monkeypatch):
        oversized = "a" * (FILE_NAME_LEN_LIMIT + 1)
        res = self._invoke(document_app_module, monkeypatch, {"kb_id": "kb1", "name": oversized})
        assert res["code"] == 101
        assert res["message"] == f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less."

    def test_filename_whitespace(self, document_app_module, monkeypatch):
        res = self._invoke(document_app_module, monkeypatch, {"kb_id": "kb1", "name": " "})
        assert res["code"] == 101
        assert res["message"] == "File name can't be empty."

    def test_kb_not_found(self, document_app_module, monkeypatch):
        module = document_app_module
        monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (False, None))
        res = self._invoke(module, monkeypatch, {"kb_id": "missing", "name": "doc.txt"})
        assert res["code"] == 102
        assert res["message"] == "Can't find this dataset!"

    def test_duplicate_name(self, document_app_module, monkeypatch):
        module = document_app_module
        monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, self._make_kb()))
        # Any non-empty query result means a document with that name exists.
        monkeypatch.setattr(module.DocumentService, "query", lambda **_kwargs: [object()])
        res = self._invoke(module, monkeypatch, {"kb_id": "kb1", "name": "doc.txt"})
        assert res["code"] == 102
        assert "Duplicated document name" in res["message"]

    def test_root_folder_missing(self, document_app_module, monkeypatch):
        module = document_app_module
        monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, self._make_kb()))
        monkeypatch.setattr(module.DocumentService, "query", lambda **_kwargs: [])
        monkeypatch.setattr(module.FileService, "get_kb_folder", lambda *_args, **_kwargs: None)
        res = self._invoke(module, monkeypatch, {"kb_id": "kb1", "name": "doc.txt"})
        assert res["code"] == 102
        assert res["message"] == "Cannot find the root folder."

    def test_kb_folder_missing(self, document_app_module, monkeypatch):
        module = document_app_module
        monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, self._make_kb()))
        monkeypatch.setattr(module.DocumentService, "query", lambda **_kwargs: [])
        monkeypatch.setattr(module.FileService, "get_kb_folder", lambda *_args, **_kwargs: {"id": "root"})
        monkeypatch.setattr(module.FileService, "new_a_file_from_kb", lambda *_args, **_kwargs: None)
        res = self._invoke(module, monkeypatch, {"kb_id": "kb1", "name": "doc.txt"})
        assert res["code"] == 102
        assert res["message"] == "Cannot find the kb folder for this file."

    def test_success(self, document_app_module, monkeypatch):
        module = document_app_module
        monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, self._make_kb()))
        monkeypatch.setattr(module.DocumentService, "query", lambda **_kwargs: [])
        monkeypatch.setattr(module.FileService, "get_kb_folder", lambda *_args, **_kwargs: {"id": "root"})
        monkeypatch.setattr(module.FileService, "new_a_file_from_kb", lambda *_args, **_kwargs: {"id": "folder"})

        class _Doc:
            # Minimal document record exposing the serializers the view uses.
            def __init__(self, doc_id):
                self.id = doc_id

            def to_json(self):
                return {"id": self.id, "name": "doc.txt", "kb_id": "kb1"}

            def to_dict(self):
                return {"id": self.id, "name": "doc.txt", "kb_id": "kb1"}

        monkeypatch.setattr(module.DocumentService, "insert", lambda _doc: _Doc("doc1"))
        monkeypatch.setattr(module.FileService, "add_file_from_kb", lambda *_args, **_kwargs: None)
        res = self._invoke(module, monkeypatch, {"kb_id": "kb1", "name": "doc.txt"})
        assert res["code"] == 0
        assert res["data"]["id"] == "doc1"
|
||||
|
||||
@ -13,6 +13,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from common import (
|
||||
document_change_status,
|
||||
@ -241,3 +244,170 @@ class TestDocumentMetadataNegative:
|
||||
res = document_set_meta(WebApiAuth, {"doc_id": doc_id, "meta": "[]"})
|
||||
assert res["code"] == 101, res
|
||||
assert "dictionary" in res["message"], res
|
||||
|
||||
|
||||
def _run(coro):
    """Drive *coro* to completion on a fresh event loop and return its result."""
    return asyncio.run(coro)
|
||||
|
||||
|
||||
@pytest.mark.p2
class TestDocumentMetadataUnit:
    """Unit tests for document_app metadata endpoints (filter, infos, summary, update)."""

    def _allow_kb(self, module, monkeypatch, kb_id="kb1", tenant_id="tenant1"):
        # Authorize the current user for exactly one KB: KnowledgebaseService.query
        # is truthy only when asked about *kb_id*.
        # (Idiom fix: the comparison already yields a bool; the former
        # `True if ... else False` ternary was redundant.)
        monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [SimpleNamespace(tenant_id=tenant_id)])
        monkeypatch.setattr(module.KnowledgebaseService, "query", lambda **_kwargs: _kwargs.get("id") == kb_id)

    def test_filter_missing_kb_id(self, document_app_module, monkeypatch):
        module = document_app_module

        async def fake_request_json():
            return {}

        monkeypatch.setattr(module, "get_request_json", fake_request_json)
        res = _run(module.get_filter())
        assert res["code"] == 101
        assert "KB ID" in res["message"]

    def test_filter_unauthorized(self, document_app_module, monkeypatch):
        module = document_app_module
        # User belongs to a tenant, but the KB lookup denies access.
        monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [SimpleNamespace(tenant_id="tenant1")])
        monkeypatch.setattr(module.KnowledgebaseService, "query", lambda **_kwargs: False)

        async def fake_request_json():
            return {"kb_id": "kb1"}

        monkeypatch.setattr(module, "get_request_json", fake_request_json)
        res = _run(module.get_filter())
        assert res["code"] == 103

    def test_filter_invalid_filters(self, document_app_module, monkeypatch):
        module = document_app_module
        self._allow_kb(module, monkeypatch)

        async def fake_request_json():
            return {"kb_id": "kb1", "run_status": ["INVALID"]}

        monkeypatch.setattr(module, "get_request_json", fake_request_json)
        res = _run(module.get_filter())
        assert res["code"] == 102
        assert "Invalid filter run status" in res["message"]

        async def fake_request_json_types():
            return {"kb_id": "kb1", "types": ["INVALID"]}

        monkeypatch.setattr(module, "get_request_json", fake_request_json_types)
        res = _run(module.get_filter())
        assert res["code"] == 102
        assert "Invalid filter conditions" in res["message"]

    def test_filter_keywords_suffix(self, document_app_module, monkeypatch):
        module = document_app_module
        self._allow_kb(module, monkeypatch)
        monkeypatch.setattr(module.DocumentService, "get_filter_by_kb_id", lambda *_args, **_kwargs: ({"run": {}}, 1))

        async def fake_request_json():
            return {"kb_id": "kb1", "keywords": "ragflow", "suffix": ["txt"]}

        monkeypatch.setattr(module, "get_request_json", fake_request_json)
        res = _run(module.get_filter())
        assert res["code"] == 0
        assert "filter" in res["data"]

    def test_filter_exception(self, document_app_module, monkeypatch):
        module = document_app_module
        self._allow_kb(module, monkeypatch)

        def raise_error(*_args, **_kwargs):
            raise RuntimeError("boom")

        monkeypatch.setattr(module.DocumentService, "get_filter_by_kb_id", raise_error)

        async def fake_request_json():
            return {"kb_id": "kb1"}

        monkeypatch.setattr(module, "get_request_json", fake_request_json)
        res = _run(module.get_filter())
        assert res["code"] == 100

    def test_infos_meta_fields(self, document_app_module, monkeypatch):
        module = document_app_module
        monkeypatch.setattr(module.DocumentService, "accessible", lambda *_args, **_kwargs: True)

        class _Docs:
            # Mimics the query result object: .dicts() yields plain records.
            def dicts(self):
                return [{"id": "doc1"}]

        monkeypatch.setattr(module.DocumentService, "get_by_ids", lambda _ids: _Docs())
        monkeypatch.setattr(module.DocMetadataService, "get_document_metadata", lambda _doc_id: {"author": "alice"})

        async def fake_request_json():
            return {"doc_ids": ["doc1"]}

        monkeypatch.setattr(module, "get_request_json", fake_request_json)
        res = _run(module.doc_infos())
        assert res["code"] == 0
        assert res["data"][0]["meta_fields"]["author"] == "alice"

    def test_metadata_summary_missing_kb_id(self, document_app_module, monkeypatch):
        module = document_app_module

        async def fake_request_json():
            return {"doc_ids": ["doc1"]}

        monkeypatch.setattr(module, "get_request_json", fake_request_json)
        res = _run(module.metadata_summary())
        assert res["code"] == 101

    def test_metadata_summary_unauthorized(self, document_app_module, monkeypatch):
        module = document_app_module
        monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [SimpleNamespace(tenant_id="tenant1")])
        monkeypatch.setattr(module.KnowledgebaseService, "query", lambda **_kwargs: False)

        async def fake_request_json():
            return {"kb_id": "kb1", "doc_ids": ["doc1"]}

        monkeypatch.setattr(module, "get_request_json", fake_request_json)
        res = _run(module.metadata_summary())
        assert res["code"] == 103

    def test_metadata_summary_success_and_exception(self, document_app_module, monkeypatch):
        module = document_app_module
        self._allow_kb(module, monkeypatch)
        monkeypatch.setattr(module.DocMetadataService, "get_metadata_summary", lambda *_args, **_kwargs: {"author": {"alice": 1}})

        async def fake_request_json():
            return {"kb_id": "kb1", "doc_ids": ["doc1"]}

        monkeypatch.setattr(module, "get_request_json", fake_request_json)
        res = _run(module.metadata_summary())
        assert res["code"] == 0
        assert "summary" in res["data"]

        # Same request again, but the service now raises -> generic error code.
        def raise_error(*_args, **_kwargs):
            raise RuntimeError("boom")

        monkeypatch.setattr(module.DocMetadataService, "get_metadata_summary", raise_error)
        res = _run(module.metadata_summary())
        assert res["code"] == 100

    def test_metadata_update_missing_kb_id(self, document_app_module, monkeypatch):
        module = document_app_module

        async def fake_request_json():
            return {"doc_ids": ["doc1"], "updates": [], "deletes": []}

        monkeypatch.setattr(module, "get_request_json", fake_request_json)
        res = _run(module.metadata_update.__wrapped__())
        assert res["code"] == 101
        assert "KB ID" in res["message"]

    def test_metadata_update_success(self, document_app_module, monkeypatch):
        module = document_app_module
        monkeypatch.setattr(module.DocMetadataService, "batch_update_metadata", lambda *_args, **_kwargs: 1)

        async def fake_request_json():
            return {"kb_id": "kb1", "doc_ids": ["doc1"], "updates": [{"key": "author", "value": "alice"}], "deletes": []}

        monkeypatch.setattr(module, "get_request_json", fake_request_json)
        res = _run(module.metadata_update.__wrapped__())
        assert res["code"] == 0
        assert res["data"]["matched_docs"] == 1
||||
|
||||
@ -13,7 +13,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from common import list_documents
|
||||
@ -178,3 +180,214 @@ class TestDocumentsList:
|
||||
responses = list(as_completed(futures))
|
||||
assert len(responses) == count, responses
|
||||
assert all(future.result()["code"] == 0 for future in futures), responses
|
||||
|
||||
|
||||
def _run(coro):
    """Drive *coro* to completion on a fresh event loop and return its result."""
    return asyncio.run(coro)
|
||||
|
||||
|
||||
class _DummyArgs(dict):
|
||||
def get(self, key, default=None):
|
||||
return super().get(key, default)
|
||||
|
||||
|
||||
@pytest.mark.p2
class TestDocumentsListUnit:
    """Unit tests for document_app.list_docs: validation, filters, metadata, errors."""

    def _set_args(self, module, monkeypatch, **kwargs):
        # Replace the module-level request with one exposing these query args.
        monkeypatch.setattr(module, "request", SimpleNamespace(args=_DummyArgs(kwargs)))

    def _allow_kb(self, module, monkeypatch, kb_id="kb1", tenant_id="tenant1"):
        # Authorize the current user for exactly one KB id.
        # (Idiom fix: the comparison already yields a bool; the former
        # `True if ... else False` ternary was redundant.)
        monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [SimpleNamespace(tenant_id=tenant_id)])
        monkeypatch.setattr(module.KnowledgebaseService, "query", lambda **_kwargs: _kwargs.get("id") == kb_id)

    def test_missing_kb_id(self, document_app_module, monkeypatch):
        module = document_app_module
        self._set_args(module, monkeypatch)

        async def fake_request_json():
            return {}

        monkeypatch.setattr(module, "get_request_json", fake_request_json)
        res = _run(module.list_docs())
        assert res["code"] == 101
        assert res["message"] == 'Lack of "KB ID"'

    def test_unauthorized_dataset(self, document_app_module, monkeypatch):
        module = document_app_module
        self._set_args(module, monkeypatch, kb_id="kb1")
        monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [SimpleNamespace(tenant_id="tenant1")])
        monkeypatch.setattr(module.KnowledgebaseService, "query", lambda **_kwargs: False)

        async def fake_request_json():
            return {}

        monkeypatch.setattr(module, "get_request_json", fake_request_json)
        res = _run(module.list_docs())
        assert res["code"] == 103
        assert "Only owner of dataset" in res["message"]

    def test_return_empty_metadata_flags(self, document_app_module, monkeypatch):
        module = document_app_module
        self._set_args(module, monkeypatch, kb_id="kb1")
        self._allow_kb(module, monkeypatch)
        monkeypatch.setattr(module.DocumentService, "get_by_kb_id", lambda *_args, **_kwargs: ([], 0))

        # Flag passed explicitly alongside a metadata filter.
        async def fake_request_json():
            return {"return_empty_metadata": "true", "metadata": {"author": "alice"}}

        monkeypatch.setattr(module, "get_request_json", fake_request_json)
        res = _run(module.list_docs())
        assert res["code"] == 0

        # Flag embedded inside the metadata object itself.
        async def fake_request_json_empty():
            return {"metadata": {"empty_metadata": True, "author": "alice"}}

        monkeypatch.setattr(module, "get_request_json", fake_request_json_empty)
        res = _run(module.list_docs())
        assert res["code"] == 0

    def test_invalid_filters(self, document_app_module, monkeypatch):
        module = document_app_module
        self._set_args(module, monkeypatch, kb_id="kb1")
        self._allow_kb(module, monkeypatch)

        async def fake_request_json():
            return {"run_status": ["INVALID"]}

        monkeypatch.setattr(module, "get_request_json", fake_request_json)
        res = _run(module.list_docs())
        assert res["code"] == 102
        assert "Invalid filter run status" in res["message"]

        async def fake_request_json_types():
            return {"types": ["INVALID"]}

        monkeypatch.setattr(module, "get_request_json", fake_request_json_types)
        res = _run(module.list_docs())
        assert res["code"] == 102
        assert "Invalid filter conditions" in res["message"]

    def test_invalid_metadata_types(self, document_app_module, monkeypatch):
        module = document_app_module
        self._set_args(module, monkeypatch, kb_id="kb1")
        self._allow_kb(module, monkeypatch)

        async def fake_request_json():
            return {"metadata_condition": "bad"}

        monkeypatch.setattr(module, "get_request_json", fake_request_json)
        res = _run(module.list_docs())
        assert res["code"] == 102
        assert "metadata_condition" in res["message"]

        async def fake_request_json_meta():
            return {"metadata": ["not", "object"]}

        monkeypatch.setattr(module, "get_request_json", fake_request_json_meta)
        res = _run(module.list_docs())
        assert res["code"] == 102
        assert "metadata must be an object" in res["message"]

    def test_metadata_condition_empty_result(self, document_app_module, monkeypatch):
        module = document_app_module
        self._set_args(module, monkeypatch, kb_id="kb1")
        self._allow_kb(module, monkeypatch)
        monkeypatch.setattr(module.DocMetadataService, "get_flatted_meta_by_kbs", lambda *_args, **_kwargs: {})
        # meta_filter matching nothing -> the endpoint reports an empty page.
        monkeypatch.setattr(module, "meta_filter", lambda *_args, **_kwargs: set())

        async def fake_request_json():
            return {"metadata_condition": {"conditions": [{"name": "author", "comparison_operator": "is", "value": "alice"}]}}

        monkeypatch.setattr(module, "get_request_json", fake_request_json)
        res = _run(module.list_docs())
        assert res["code"] == 0
        assert res["data"]["total"] == 0

    def test_metadata_values_intersection(self, document_app_module, monkeypatch):
        module = document_app_module
        self._set_args(module, monkeypatch, kb_id="kb1")
        self._allow_kb(module, monkeypatch)
        # doc2 is the only id matching both author=alice and topic=rag.
        metas = {
            "author": {"alice": ["doc1", "doc2"]},
            "topic": {"rag": ["doc2"]},
        }
        monkeypatch.setattr(module.DocMetadataService, "get_flatted_meta_by_kbs", lambda *_args, **_kwargs: metas)

        captured = {}

        # Naming fix: the varargs are actually read (position 9 carries the
        # doc-id filter), so they must not carry an "unused" underscore name.
        def fake_get_by_kb_id(*args, **_kwargs):
            if len(args) >= 10:
                captured["doc_ids_filter"] = args[9]
            else:
                captured["doc_ids_filter"] = None
            return ([{"id": "doc2", "thumbnail": "", "parser_config": {}}], 1)

        monkeypatch.setattr(module.DocumentService, "get_by_kb_id", fake_get_by_kb_id)

        async def fake_request_json():
            # Blank / None entries in a value list must be ignored.
            return {"metadata": {"author": ["alice", " ", None], "topic": "rag"}}

        monkeypatch.setattr(module, "get_request_json", fake_request_json)
        res = _run(module.list_docs())
        assert res["code"] == 0
        assert captured["doc_ids_filter"] == ["doc2"]

    def test_metadata_intersection_empty(self, document_app_module, monkeypatch):
        module = document_app_module
        self._set_args(module, monkeypatch, kb_id="kb1")
        self._allow_kb(module, monkeypatch)
        # Disjoint doc-id sets -> the intersection is empty.
        metas = {
            "author": {"alice": ["doc1"]},
            "topic": {"rag": ["doc2"]},
        }
        monkeypatch.setattr(module.DocMetadataService, "get_flatted_meta_by_kbs", lambda *_args, **_kwargs: metas)

        async def fake_request_json():
            return {"metadata": {"author": "alice", "topic": "rag"}}

        monkeypatch.setattr(module, "get_request_json", fake_request_json)
        res = _run(module.list_docs())
        assert res["code"] == 0
        assert res["data"]["total"] == 0

    def test_desc_time_and_schema(self, document_app_module, monkeypatch):
        module = document_app_module
        # Time window [150, 250] keeps only the create_time=200 document.
        self._set_args(module, monkeypatch, kb_id="kb1", desc="false", create_time_from="150", create_time_to="250")
        self._allow_kb(module, monkeypatch)

        docs = [
            {"id": "doc1", "thumbnail": "", "parser_config": {"metadata": {"a": 1}}, "create_time": 100},
            {"id": "doc2", "thumbnail": "", "parser_config": {"metadata": {"b": 2}}, "create_time": 200},
        ]

        def fake_get_by_kb_id(*_args, **_kwargs):
            return (docs, 2)

        monkeypatch.setattr(module.DocumentService, "get_by_kb_id", fake_get_by_kb_id)
        monkeypatch.setattr(module, "turn2jsonschema", lambda _meta: {"schema": True})

        async def fake_request_json():
            return {}

        monkeypatch.setattr(module, "get_request_json", fake_request_json)
        res = _run(module.list_docs())
        assert res["code"] == 0
        assert len(res["data"]["docs"]) == 1
        assert res["data"]["docs"][0]["parser_config"]["metadata"] == {"schema": True}

    def test_exception_path(self, document_app_module, monkeypatch):
        module = document_app_module
        self._set_args(module, monkeypatch, kb_id="kb1")
        self._allow_kb(module, monkeypatch)

        def raise_error(*_args, **_kwargs):
            raise RuntimeError("boom")

        monkeypatch.setattr(module.DocumentService, "get_by_kb_id", raise_error)

        async def fake_request_json():
            return {}

        monkeypatch.setattr(module, "get_request_json", fake_request_json)
        res = _run(module.list_docs())
        assert res["code"] == 100
|
||||
|
||||
@ -13,7 +13,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import string
|
||||
from types import SimpleNamespace
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
import pytest
|
||||
@ -21,6 +23,7 @@ from common import list_kbs, upload_documents
|
||||
from configs import DOCUMENT_NAME_LIMIT, INVALID_API_TOKEN
|
||||
from libs.auth import RAGFlowWebApiAuth
|
||||
from utils.file_utils import create_txt_file
|
||||
from api.constants import FILE_NAME_LEN_LIMIT
|
||||
|
||||
|
||||
@pytest.mark.p1
|
||||
@ -189,3 +192,288 @@ class TestDocumentsUpload:
|
||||
|
||||
res = list_kbs(WebApiAuth)
|
||||
assert res["data"]["kbs"][0]["doc_num"] == count, res
|
||||
|
||||
|
||||
class _AwaitableValue:
|
||||
def __init__(self, value):
|
||||
self._value = value
|
||||
|
||||
def __await__(self):
|
||||
async def _coro():
|
||||
return self._value
|
||||
|
||||
return _coro().__await__()
|
||||
|
||||
|
||||
class _DummyFiles(dict):
|
||||
def getlist(self, key):
|
||||
value = self.get(key, [])
|
||||
if isinstance(value, list):
|
||||
return value
|
||||
return [value]
|
||||
|
||||
|
||||
class _DummyFile:
|
||||
def __init__(self, filename):
|
||||
self.filename = filename
|
||||
self.closed = False
|
||||
self.stream = self
|
||||
|
||||
def close(self):
|
||||
self.closed = True
|
||||
|
||||
|
||||
class _DummyRequest:
    """Quart-like request whose .form and .files resolve through `await`."""

    def __init__(self, form=None, files=None):
        self._form = form or {}
        self._files = files or _DummyFiles()

    @property
    def form(self):
        # Views do `await request.form`, so hand back an awaitable wrapper.
        return _AwaitableValue(self._form)

    @property
    def files(self):
        return _AwaitableValue(self._files)
|
||||
|
||||
|
||||
def _run(coro):
    """Drive *coro* to completion on a fresh event loop and return its result."""
    return asyncio.run(coro)
|
||||
|
||||
|
||||
@pytest.mark.p2
class TestDocumentsUploadUnit:
    """Unit tests for document_app.upload, one per validation/error branch."""

    def test_missing_kb_id(self, document_app_module, monkeypatch):
        # An empty kb_id in the form is rejected before any file handling.
        module = document_app_module
        monkeypatch.setattr(module, "request", _DummyRequest(form={"kb_id": ""}, files=_DummyFiles()))
        res = _run(module.upload.__wrapped__())
        assert res["code"] == 101
        assert res["message"] == 'Lack of "KB ID"'

    def test_missing_file_part(self, document_app_module, monkeypatch):
        # A multipart body without a "file" entry is rejected.
        module = document_app_module
        monkeypatch.setattr(module, "request", _DummyRequest(form={"kb_id": "kb1"}, files=_DummyFiles()))
        res = _run(module.upload.__wrapped__())
        assert res["code"] == 101
        assert res["message"] == "No file part!"

    def test_empty_filename_closes_files(self, document_app_module, monkeypatch):
        # A file entry with an empty filename is rejected AND its handle closed.
        module = document_app_module
        file_obj = _DummyFile("")
        files = _DummyFiles({"file": [file_obj]})
        monkeypatch.setattr(module, "request", _DummyRequest(form={"kb_id": "kb1"}, files=files))
        res = _run(module.upload.__wrapped__())
        assert res["code"] == 101
        assert res["message"] == "No file selected!"
        assert file_obj.closed is True

    def test_filename_too_long(self, document_app_module, monkeypatch):
        module = document_app_module
        long_name = "a" * (FILE_NAME_LEN_LIMIT + 1)
        file_obj = _DummyFile(long_name)
        files = _DummyFiles({"file": [file_obj]})
        monkeypatch.setattr(module, "request", _DummyRequest(form={"kb_id": "kb1"}, files=files))
        res = _run(module.upload.__wrapped__())
        assert res["code"] == 101
        assert res["message"] == f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less."

    def test_invalid_kb_id_raises(self, document_app_module, monkeypatch):
        # An unknown kb_id surfaces as LookupError rather than an error payload.
        module = document_app_module
        file_obj = _DummyFile("ragflow_test.txt")
        files = _DummyFiles({"file": [file_obj]})
        monkeypatch.setattr(module, "request", _DummyRequest(form={"kb_id": "missing"}, files=files))
        monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (False, None))
        with pytest.raises(LookupError):
            _run(module.upload.__wrapped__())

    def test_no_permission(self, document_app_module, monkeypatch):
        # A valid KB the user's team may not write to yields code 109.
        module = document_app_module
        kb = SimpleNamespace(id="kb1", tenant_id="tenant1", name="kb", parser_id="parser", parser_config={})
        monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, kb))
        monkeypatch.setattr(module, "check_kb_team_permission", lambda *_args, **_kwargs: False)
        file_obj = _DummyFile("ragflow_test.txt")
        files = _DummyFiles({"file": [file_obj]})
        monkeypatch.setattr(module, "request", _DummyRequest(form={"kb_id": "kb1"}, files=files))
        res = _run(module.upload.__wrapped__())
        assert res["code"] == 109
        assert res["message"] == "No authorization."

    def test_thread_pool_errors(self, document_app_module, monkeypatch):
        # Per-file errors reported by the worker pool become a 500 with the
        # failing file names in the data payload.
        module = document_app_module
        kb = SimpleNamespace(id="kb1", tenant_id="tenant1", name="kb", parser_id="parser", parser_config={})
        monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, kb))
        monkeypatch.setattr(module, "check_kb_team_permission", lambda *_args, **_kwargs: True)

        async def fake_thread_pool_exec(*_args, **_kwargs):
            # (errors, [(filename, blob), ...]) as the real helper returns.
            return (["unsupported type"], [("file1", "blob")])

        monkeypatch.setattr(module, "thread_pool_exec", fake_thread_pool_exec)
        file_obj = _DummyFile("ragflow_test.txt")
        files = _DummyFiles({"file": [file_obj]})
        monkeypatch.setattr(module, "request", _DummyRequest(form={"kb_id": "kb1"}, files=files))
        res = _run(module.upload.__wrapped__())
        assert res["code"] == 500
        assert "unsupported type" in res["message"]
        assert res["data"] == ["file1"]

    def test_empty_upload_result(self, document_app_module, monkeypatch):
        # No errors but also no stored files -> treated as a bad file format.
        module = document_app_module
        kb = SimpleNamespace(id="kb1", tenant_id="tenant1", name="kb", parser_id="parser", parser_config={})
        monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, kb))
        monkeypatch.setattr(module, "check_kb_team_permission", lambda *_args, **_kwargs: True)

        async def fake_thread_pool_exec(*_args, **_kwargs):
            return (None, [])

        monkeypatch.setattr(module, "thread_pool_exec", fake_thread_pool_exec)
        file_obj = _DummyFile("ragflow_test.txt")
        files = _DummyFiles({"file": [file_obj]})
        monkeypatch.setattr(module, "request", _DummyRequest(form={"kb_id": "kb1"}, files=files))
        res = _run(module.upload.__wrapped__())
        assert res["code"] == 102
        assert "file format" in res["message"]
|
||||
|
||||
|
||||
@pytest.mark.p2
class TestWebCrawlUnit:
    """Unit tests for document_app.web_crawl with all collaborators monkeypatched:
    input validation, permission checks, download failure, parser overrides by
    file type, and the generic exception path."""

    def test_missing_kb_id(self, document_app_module, monkeypatch):
        # An empty kb_id must be rejected before any service is touched.
        module = document_app_module
        monkeypatch.setattr(module, "request", _DummyRequest(form={"kb_id": "", "name": "doc", "url": "http://example.com"}))
        res = _run(module.web_crawl.__wrapped__())
        assert res["code"] == 101
        assert res["message"] == 'Lack of "KB ID"'

    def test_invalid_url(self, document_app_module, monkeypatch):
        # A malformed URL is rejected with the dedicated validation message.
        module = document_app_module
        monkeypatch.setattr(module, "request", _DummyRequest(form={"kb_id": "kb1", "name": "doc", "url": "not-a-url"}))
        res = _run(module.web_crawl.__wrapped__())
        assert res["code"] == 101
        assert res["message"] == "The URL format is invalid"

    def test_invalid_kb_id_raises(self, document_app_module, monkeypatch):
        # An unknown dataset id surfaces as a LookupError from the route body.
        module = document_app_module
        monkeypatch.setattr(module, "is_valid_url", lambda _url: True)
        monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (False, None))
        monkeypatch.setattr(module, "request", _DummyRequest(form={"kb_id": "missing", "name": "doc", "url": "http://example.com"}))
        with pytest.raises(LookupError):
            _run(module.web_crawl.__wrapped__())

    def test_no_permission(self, document_app_module, monkeypatch):
        # A failed team-permission check yields 109 / "No authorization.".
        module = document_app_module
        kb = SimpleNamespace(id="kb1", tenant_id="tenant1", name="kb", parser_id="parser", parser_config={})
        monkeypatch.setattr(module, "is_valid_url", lambda _url: True)
        monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, kb))
        monkeypatch.setattr(module, "check_kb_team_permission", lambda *_args, **_kwargs: False)
        monkeypatch.setattr(module, "request", _DummyRequest(form={"kb_id": "kb1", "name": "doc", "url": "http://example.com"}))
        res = _run(module.web_crawl.__wrapped__())
        assert res["code"] == 109
        assert res["message"] == "No authorization."

    def test_download_failure(self, document_app_module, monkeypatch):
        # html2pdf returning None is reported as a download failure.
        module = document_app_module
        kb = SimpleNamespace(id="kb1", tenant_id="tenant1", name="kb", parser_id="parser", parser_config={})
        monkeypatch.setattr(module, "is_valid_url", lambda _url: True)
        monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, kb))
        monkeypatch.setattr(module, "check_kb_team_permission", lambda *_args, **_kwargs: True)
        monkeypatch.setattr(module, "html2pdf", lambda _url: None)
        monkeypatch.setattr(module, "request", _DummyRequest(form={"kb_id": "kb1", "name": "doc", "url": "http://example.com"}))
        res = _run(module.web_crawl.__wrapped__())
        assert res["code"] == 100
        assert "Download failure" in res["message"]

    def test_unsupported_type(self, document_app_module, monkeypatch):
        # A filename resolving to an unsupported extension is rejected.
        module = document_app_module
        kb = SimpleNamespace(id="kb1", tenant_id="tenant1", name="kb", parser_id="parser", parser_config={})
        monkeypatch.setattr(module, "is_valid_url", lambda _url: True)
        monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, kb))
        monkeypatch.setattr(module, "check_kb_team_permission", lambda *_args, **_kwargs: True)
        monkeypatch.setattr(module, "html2pdf", lambda _url: b"%PDF-1.4")
        monkeypatch.setattr(module.FileService, "get_root_folder", lambda _uid: {"id": "root"})
        monkeypatch.setattr(module.FileService, "init_knowledgebase_docs", lambda *_args, **_kwargs: None)
        monkeypatch.setattr(module.FileService, "get_kb_folder", lambda *_args, **_kwargs: {"id": "kb_root"})
        monkeypatch.setattr(module.FileService, "new_a_file_from_kb", lambda *_args, **_kwargs: {"id": "kb_folder"})
        monkeypatch.setattr(module, "duplicate_name", lambda *_args, **_kwargs: "bad.exe")
        monkeypatch.setattr(module, "request", _DummyRequest(form={"kb_id": "kb1", "name": "doc", "url": "http://example.com"}))
        res = _run(module.web_crawl.__wrapped__())
        assert res["code"] == 100
        assert "supported yet" in res["message"]

    @pytest.mark.parametrize(
        "filename,filetype,expected_parser",
        [
            ("image.png", "visual", "picture"),
            ("sound.mp3", "aural", "audio"),
            ("deck.pptx", "doc", "presentation"),
            ("mail.eml", "doc", "email"),
        ],
    )
    def test_success_parser_overrides(self, document_app_module, monkeypatch, filename, filetype, expected_parser):
        # On success, the stored document's parser_id is overridden per file type.
        module = document_app_module
        kb = SimpleNamespace(id="kb1", tenant_id="tenant1", name="kb", parser_id="parser", parser_config={})
        captured = {}

        class _Storage:
            # Minimal storage double: nothing exists, puts are recorded.
            def obj_exist(self, *_args, **_kwargs):
                return False

            def put(self, *_args, **_kwargs):
                captured["put"] = True

        def insert_doc(doc):
            # Capture the document dict handed to DocumentService.insert.
            captured["doc"] = doc

        monkeypatch.setattr(module, "is_valid_url", lambda _url: True)
        monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, kb))
        monkeypatch.setattr(module, "check_kb_team_permission", lambda *_args, **_kwargs: True)
        monkeypatch.setattr(module, "html2pdf", lambda _url: b"%PDF-1.4")
        monkeypatch.setattr(module.FileService, "get_root_folder", lambda _uid: {"id": "root"})
        monkeypatch.setattr(module.FileService, "init_knowledgebase_docs", lambda *_args, **_kwargs: None)
        monkeypatch.setattr(module.FileService, "get_kb_folder", lambda *_args, **_kwargs: {"id": "kb_root"})
        monkeypatch.setattr(module.FileService, "new_a_file_from_kb", lambda *_args, **_kwargs: {"id": "kb_folder"})
        monkeypatch.setattr(module, "duplicate_name", lambda *_args, **_kwargs: filename)
        monkeypatch.setattr(module, "filename_type", lambda _name: filetype)
        monkeypatch.setattr(module, "thumbnail", lambda *_args, **_kwargs: "")
        monkeypatch.setattr(module, "get_uuid", lambda: "doc-1")
        monkeypatch.setattr(module.settings, "STORAGE_IMPL", _Storage())
        monkeypatch.setattr(module.DocumentService, "insert", insert_doc)
        monkeypatch.setattr(module.FileService, "add_file_from_kb", lambda *_args, **_kwargs: None)
        monkeypatch.setattr(module, "request", _DummyRequest(form={"kb_id": "kb1", "name": "doc", "url": "http://example.com"}))

        res = _run(module.web_crawl.__wrapped__())
        assert res["code"] == 0
        assert captured["doc"]["parser_id"] == expected_parser
        assert captured["put"] is True

    def test_exception_path(self, document_app_module, monkeypatch):
        # An exception thrown during insert is mapped to a code-100 envelope.
        module = document_app_module
        kb = SimpleNamespace(id="kb1", tenant_id="tenant1", name="kb", parser_id="parser", parser_config={})

        class _Storage:
            def obj_exist(self, *_args, **_kwargs):
                return False

            def put(self, *_args, **_kwargs):
                return None

        def insert_doc(_doc):
            raise RuntimeError("boom")

        monkeypatch.setattr(module, "is_valid_url", lambda _url: True)
        monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, kb))
        monkeypatch.setattr(module, "check_kb_team_permission", lambda *_args, **_kwargs: True)
        monkeypatch.setattr(module, "html2pdf", lambda _url: b"%PDF-1.4")
        monkeypatch.setattr(module.FileService, "get_root_folder", lambda _uid: {"id": "root"})
        monkeypatch.setattr(module.FileService, "init_knowledgebase_docs", lambda *_args, **_kwargs: None)
        monkeypatch.setattr(module.FileService, "get_kb_folder", lambda *_args, **_kwargs: {"id": "kb_root"})
        monkeypatch.setattr(module.FileService, "new_a_file_from_kb", lambda *_args, **_kwargs: {"id": "kb_folder"})
        monkeypatch.setattr(module, "duplicate_name", lambda *_args, **_kwargs: "doc.pdf")
        monkeypatch.setattr(module, "filename_type", lambda _name: "pdf")
        monkeypatch.setattr(module, "thumbnail", lambda *_args, **_kwargs: "")
        monkeypatch.setattr(module, "get_uuid", lambda: "doc-1")
        monkeypatch.setattr(module.settings, "STORAGE_IMPL", _Storage())
        monkeypatch.setattr(module.DocumentService, "insert", insert_doc)
        monkeypatch.setattr(module.FileService, "add_file_from_kb", lambda *_args, **_kwargs: None)
        monkeypatch.setattr(module, "request", _DummyRequest(form={"kb_id": "kb1", "name": "doc", "url": "http://example.com"}))

        res = _run(module.web_crawl.__wrapped__())
        assert res["code"] == 100
|
||||
@ -206,3 +206,25 @@ class TestKbPipelineLogs:
|
||||
res = kb_delete_pipeline_logs(WebApiAuth, params={"kb_id": kb_id}, payload={"log_ids": []})
|
||||
assert res["code"] == 0, res
|
||||
assert res["data"] is True, res
|
||||
|
||||
@pytest.mark.p3
def test_list_pipeline_logs_missing_kb_id(self, WebApiAuth):
    """Omitting kb_id from the pipeline-log listing fails with code 101."""
    response = kb_list_pipeline_logs(WebApiAuth, params={}, payload={})
    assert response["code"] == 101, response
    assert "KB ID" in response["message"], response
|
||||
|
||||
@pytest.mark.p3
def test_list_pipeline_logs_abnormal_date_filter(self, WebApiAuth, add_document):
    """This create-date window is rejected as an abnormal filter (code 102)."""
    kb_id, _ = add_document
    query = {
        "kb_id": kb_id,
        "desc": "false",
        "create_date_from": "2025-01-01",
        "create_date_to": "2025-02-01",
    }
    response = kb_list_pipeline_logs(WebApiAuth, params=query, payload={})
    assert response["code"] == 102, response
    assert "Create data filter is abnormal." in response["message"], response
|
||||
|
||||
1048
test/testcases/test_web_api/test_kb_app/test_kb_routes_unit.py
Normal file
1048
test/testcases/test_web_api/test_kb_app/test_kb_routes_unit.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -17,9 +17,11 @@ import uuid
|
||||
|
||||
import pytest
|
||||
from common import (
|
||||
delete_knowledge_graph,
|
||||
kb_basic_info,
|
||||
kb_get_meta,
|
||||
kb_update_metadata_setting,
|
||||
knowledge_graph,
|
||||
list_tags,
|
||||
list_tags_from_kbs,
|
||||
rename_tags,
|
||||
@ -121,6 +123,20 @@ class TestAuthorization:
|
||||
assert res["code"] == expected_code, res
|
||||
assert expected_fragment in res["message"], res
|
||||
|
||||
@pytest.mark.p2
@pytest.mark.parametrize("invalid_auth, expected_code, expected_fragment", INVALID_AUTH_CASES)
def test_knowledge_graph_auth_invalid(self, invalid_auth, expected_code, expected_fragment):
    """Every invalid-auth flavour is rejected by the knowledge-graph endpoint."""
    response = knowledge_graph(invalid_auth, "kb_id")
    assert response["code"] == expected_code, response
    assert expected_fragment in response["message"], response
|
||||
|
||||
@pytest.mark.p2
@pytest.mark.parametrize("invalid_auth, expected_code, expected_fragment", INVALID_AUTH_CASES)
def test_delete_knowledge_graph_auth_invalid(self, invalid_auth, expected_code, expected_fragment):
    """Every invalid-auth flavour is rejected by the delete-knowledge-graph endpoint."""
    response = delete_knowledge_graph(invalid_auth, "kb_id")
    assert response["code"] == expected_code, response
    assert expected_fragment in response["message"], response
|
||||
|
||||
|
||||
class TestKbTagsMeta:
|
||||
@pytest.mark.p2
|
||||
@ -205,6 +221,22 @@ class TestKbTagsMeta:
|
||||
assert res["data"]["id"] == kb_id, res
|
||||
assert res["data"]["parser_config"]["metadata"] == metadata, res
|
||||
|
||||
@pytest.mark.p2
def test_knowledge_graph(self, WebApiAuth, add_dataset):
    """Fetching the knowledge graph returns a dict carrying graph and mind_map."""
    response = knowledge_graph(WebApiAuth, add_dataset)
    assert response["code"] == 0, response
    payload = response["data"]
    assert isinstance(payload, dict), response
    for key in ("graph", "mind_map"):
        assert key in payload, response
|
||||
|
||||
@pytest.mark.p2
def test_delete_knowledge_graph(self, WebApiAuth, add_dataset):
    """Deleting a dataset's knowledge graph succeeds and reports True."""
    response = delete_knowledge_graph(WebApiAuth, add_dataset)
    assert response["code"] == 0, response
    assert response["data"] is True, response
|
||||
|
||||
|
||||
class TestKbTagsMetaNegative:
|
||||
@pytest.mark.p3
|
||||
@ -249,3 +281,15 @@ class TestKbTagsMetaNegative:
|
||||
assert res["code"] == 101, res
|
||||
assert "required argument are missing" in res["message"], res
|
||||
assert "metadata" in res["message"], res
|
||||
|
||||
@pytest.mark.p3
def test_knowledge_graph_invalid_kb(self, WebApiAuth):
    """An unknown dataset id yields 109 with a 'No authorization' message."""
    response = knowledge_graph(WebApiAuth, "invalid_kb_id")
    assert response["code"] == 109, response
    assert "No authorization" in response["message"], response
|
||||
|
||||
@pytest.mark.p3
def test_delete_knowledge_graph_invalid_kb(self, WebApiAuth):
    """Deleting the graph of an unknown dataset id yields 109 / 'No authorization'."""
    response = delete_knowledge_graph(WebApiAuth, "invalid_kb_id")
    assert response["code"] == 109, response
    assert "No authorization" in response["message"], response
|
||||
|
||||
@ -182,3 +182,20 @@ class TestDatasetsList:
|
||||
res = list_kbs(WebApiAuth, params)
|
||||
assert res["code"] == 0, res
|
||||
assert len(res["data"]["kbs"]) == expected_page_size, res
|
||||
|
||||
@pytest.mark.p2
def test_owner_ids_payload_mode(self, WebApiAuth):
    """Filtering by owner_ids in the payload returns only that tenant's datasets."""
    baseline = list_kbs(WebApiAuth, {"page_size": 10})
    assert baseline["code"] == 0, baseline
    known_kbs = baseline["data"]["kbs"]
    assert known_kbs, baseline
    owner = known_kbs[0]["tenant_id"]

    filtered = list_kbs(
        WebApiAuth,
        params={"page": 1, "page_size": 2, "desc": "false"},
        payload={"owner_ids": [owner]},
    )
    assert filtered["code"] == 0, filtered
    page = filtered["data"]["kbs"]
    assert filtered["data"]["total"] >= len(page), filtered
    assert len(page) <= 2, filtered
    for kb in page:
        assert kb["tenant_id"] == owner, filtered
|
||||
|
||||
290
test/testcases/test_web_api/test_llm_app/test_llm_list_unit.py
Normal file
290
test/testcases/test_web_api/test_llm_app/test_llm_list_unit.py
Normal file
@ -0,0 +1,290 @@
|
||||
#
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import importlib.util
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from types import ModuleType, SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class _DummyManager:
|
||||
def route(self, *_args, **_kwargs):
|
||||
def decorator(func):
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class _ExprField:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
def __eq__(self, other):
|
||||
return (self.name, other)
|
||||
|
||||
|
||||
class _DummyTenantLLMModel:
    # Stand-in for the TenantLLM model class: exposes the two column
    # attributes that filter expressions are built from in the tests.
    tenant_id = _ExprField("tenant_id")
    llm_factory = _ExprField("llm_factory")
||||
|
||||
|
||||
class _TenantLLMRow:
|
||||
def __init__(self, *, llm_name, llm_factory, model_type, api_key="key", status="1"):
|
||||
self.llm_name = llm_name
|
||||
self.llm_factory = llm_factory
|
||||
self.model_type = model_type
|
||||
self.api_key = api_key
|
||||
self.status = status
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"llm_name": self.llm_name,
|
||||
"llm_factory": self.llm_factory,
|
||||
"model_type": self.model_type,
|
||||
"status": self.status,
|
||||
}
|
||||
|
||||
|
||||
class _LLMRow:
|
||||
def __init__(self, *, llm_name, fid, model_type, status="1"):
|
||||
self.llm_name = llm_name
|
||||
self.fid = fid
|
||||
self.model_type = model_type
|
||||
self.status = status
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"llm_name": self.llm_name,
|
||||
"fid": self.fid,
|
||||
"model_type": self.model_type,
|
||||
"status": self.status,
|
||||
}
|
||||
|
||||
|
||||
def _run(coro):
|
||||
return asyncio.run(coro)
|
||||
|
||||
|
||||
def _load_llm_app(monkeypatch):
    """Import api/apps/llm_app.py in isolation.

    Every heavy dependency (quart, DB services, api utils, constants, rag.llm)
    is replaced by an in-memory stub registered in ``sys.modules`` via
    ``monkeypatch.setitem`` before the module file is executed, so the route
    module loads without a database, web server, or model registry.
    Returns the freshly executed module object.
    """
    repo_root = Path(__file__).resolve().parents[4]

    # Stub quart: only `request` is consulted by the routes under test.
    quart_mod = ModuleType("quart")
    quart_mod.request = SimpleNamespace(args={})
    monkeypatch.setitem(sys.modules, "quart", quart_mod)

    # Stub the api.apps package: no-op login decorator and a fixed user.
    apps_mod = ModuleType("api.apps")
    apps_mod.__path__ = [str(repo_root / "api" / "apps")]
    apps_mod.login_required = lambda fn: fn
    apps_mod.current_user = SimpleNamespace(id="tenant-1")
    monkeypatch.setitem(sys.modules, "api.apps", apps_mod)

    # Stub the tenant LLM service module with inert defaults.
    tenant_llm_mod = ModuleType("api.db.services.tenant_llm_service")

    class _StubLLMFactoriesService:
        @staticmethod
        def query(**_kwargs):
            return []

    class _StubTenantLLMService:
        @staticmethod
        def ensure_mineru_from_env(_tenant_id):
            return None

        @staticmethod
        def query(**_kwargs):
            return []

        @staticmethod
        def get_my_llms(_tenant_id):
            return []

        @staticmethod
        def save(**_kwargs):
            return True

        @staticmethod
        def filter_delete(_filters):
            return True

    tenant_llm_mod.LLMFactoriesService = _StubLLMFactoriesService
    tenant_llm_mod.TenantLLMService = _StubTenantLLMService
    monkeypatch.setitem(sys.modules, "api.db.services.tenant_llm_service", tenant_llm_mod)

    # Stub the catalogue LLM service module.
    llm_service_mod = ModuleType("api.db.services.llm_service")

    class _StubLLMService:
        @staticmethod
        def get_all():
            return []

        @staticmethod
        def query(**_kwargs):
            return []

    llm_service_mod.LLMService = _StubLLMService
    monkeypatch.setitem(sys.modules, "api.db.services.llm_service", llm_service_mod)

    # Stub api.utils.api_utils with plain-dict response envelopes.
    api_utils_mod = ModuleType("api.utils.api_utils")
    api_utils_mod.get_allowed_llm_factories = lambda: []
    api_utils_mod.get_data_error_result = lambda message="", code=400, data=None: {
        "code": code,
        "message": message,
        "data": data,
    }
    api_utils_mod.get_json_result = lambda data=None, message="", code=0: {
        "code": code,
        "message": message,
        "data": data,
    }

    async def _get_request_json():
        return {}

    api_utils_mod.get_request_json = _get_request_json
    api_utils_mod.server_error_response = lambda exc: {"code": 500, "message": str(exc), "data": None}
    api_utils_mod.validate_request = lambda *_args, **_kwargs: (lambda fn: fn)
    monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod)

    # Stub common.constants with just the enums/values the module reads.
    constants_mod = ModuleType("common.constants")
    constants_mod.StatusEnum = SimpleNamespace(VALID=SimpleNamespace(value="1"), INVALID=SimpleNamespace(value="0"))
    constants_mod.LLMType = SimpleNamespace(
        CHAT="chat",
        EMBEDDING="embedding",
        SPEECH2TEXT="speech2text",
        IMAGE2TEXT="image2text",
        RERANK="rerank",
        TTS="tts",
        OCR="ocr",
    )
    monkeypatch.setitem(sys.modules, "common.constants", constants_mod)

    # Stub the DB model module with the dummy TenantLLM column holder.
    db_models_mod = ModuleType("api.db.db_models")
    db_models_mod.TenantLLM = _DummyTenantLLMModel
    monkeypatch.setitem(sys.modules, "api.db.db_models", db_models_mod)

    # Stub image helper: identity function.
    base64_mod = ModuleType("rag.utils.base64_image")
    base64_mod.test_image = lambda _s: _s
    monkeypatch.setitem(sys.modules, "rag.utils.base64_image", base64_mod)

    # Stub rag.llm model registries as empty mappings.
    rag_llm_mod = ModuleType("rag.llm")
    rag_llm_mod.EmbeddingModel = {}
    rag_llm_mod.ChatModel = {}
    rag_llm_mod.RerankModel = {}
    rag_llm_mod.CvModel = {}
    rag_llm_mod.TTSModel = {}
    rag_llm_mod.OcrModel = {}
    rag_llm_mod.Seq2txtModel = {}
    monkeypatch.setitem(sys.modules, "rag.llm", rag_llm_mod)

    # Execute the real route module file under a private module name,
    # injecting the dummy blueprint manager before exec so its decorators
    # resolve to no-ops.
    module_path = repo_root / "api" / "apps" / "llm_app.py"
    spec = importlib.util.spec_from_file_location("test_llm_list_unit_module", module_path)
    module = importlib.util.module_from_spec(spec)
    module.manager = _DummyManager()
    spec.loader.exec_module(module)
    return module
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_list_app_grouping_availability_and_merge(monkeypatch):
    """list_app merges tenant-configured models with catalogue rows, groups the
    result per factory, and marks the merged entries available; the env-driven
    TEI builtin model appears under the Builtin factory."""
    module = _load_llm_app(monkeypatch)

    # Record which tenant the MinerU bootstrap was invoked for.
    ensure_calls = []
    monkeypatch.setattr(module.TenantLLMService, "ensure_mineru_from_env", lambda tenant_id: ensure_calls.append(tenant_id))

    # Tenant rows: one overlaps the catalogue (fast-emb), one is tenant-only.
    tenant_rows = [
        _TenantLLMRow(llm_name="fast-emb", llm_factory="FastEmbed", model_type="embedding", api_key="k1", status="1"),
        _TenantLLMRow(llm_name="tenant-only", llm_factory="CustomFactory", model_type="chat", api_key="k2", status="1"),
    ]
    monkeypatch.setattr(module.TenantLLMService, "query", lambda **_kwargs: tenant_rows)

    # Catalogue rows, including the TEI embed model referenced via env vars.
    all_llms = [
        _LLMRow(llm_name="tei-embed", fid="Builtin", model_type="embedding", status="1"),
        _LLMRow(llm_name="fast-emb", fid="FastEmbed", model_type="embedding", status="1"),
        _LLMRow(llm_name="not-in-status", fid="Other", model_type="chat", status="1"),
    ]
    monkeypatch.setattr(module.LLMService, "get_all", lambda: all_llms)

    monkeypatch.setattr(module, "request", SimpleNamespace(args={}))
    monkeypatch.setenv("COMPOSE_PROFILES", "tei-cpu")
    monkeypatch.setenv("TEI_MODEL", "tei-embed")

    res = _run(module.list_app())
    assert res["code"] == 0
    # MinerU bootstrap ran exactly once, for the stubbed current user.
    assert ensure_calls == ["tenant-1"]

    data = res["data"]
    assert {"Builtin", "FastEmbed", "CustomFactory"}.issubset(set(data.keys()))

    builtin = data["Builtin"][0]
    assert builtin["llm_name"] == "tei-embed"
    assert builtin["available"] is True

    fastembed = data["FastEmbed"][0]
    assert fastembed["llm_name"] == "fast-emb"
    assert fastembed["available"] is True

    tenant_only = data["CustomFactory"][0]
    assert tenant_only["llm_name"] == "tenant-only"
    assert tenant_only["available"] is True
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_list_app_model_type_filter(monkeypatch):
    """list_app honours the ``model_type`` query arg: only factories that keep
    at least one model of the requested type remain in the response."""
    module = _load_llm_app(monkeypatch)

    monkeypatch.setattr(module.TenantLLMService, "ensure_mineru_from_env", lambda _tenant_id: None)
    # Tenant has one embedding model and one chat model.
    monkeypatch.setattr(
        module.TenantLLMService,
        "query",
        lambda **_kwargs: [
            _TenantLLMRow(llm_name="fast-emb", llm_factory="FastEmbed", model_type="embedding", api_key="k1", status="1"),
            _TenantLLMRow(llm_name="tenant-only", llm_factory="CustomFactory", model_type="chat", api_key="k2", status="1"),
        ],
    )
    # The catalogue contributes embedding models only.
    monkeypatch.setattr(
        module.LLMService,
        "get_all",
        lambda: [
            _LLMRow(llm_name="tei-embed", fid="Builtin", model_type="embedding", status="1"),
            _LLMRow(llm_name="fast-emb", fid="FastEmbed", model_type="embedding", status="1"),
        ],
    )

    monkeypatch.setattr(module, "request", SimpleNamespace(args={"model_type": "chat"}))
    res = _run(module.list_app())
    assert res["code"] == 0
    # Only the chat-capable factory survives the filter.
    assert list(res["data"].keys()) == ["CustomFactory"]
    assert res["data"]["CustomFactory"][0]["model_type"] == "chat"
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_list_app_exception_path(monkeypatch):
    """Any exception raised while listing models is mapped to a 500 envelope
    carrying the original error text."""
    module = _load_llm_app(monkeypatch)

    monkeypatch.setattr(module, "request", SimpleNamespace(args={}))
    monkeypatch.setattr(module.TenantLLMService, "ensure_mineru_from_env", lambda _tenant_id: None)

    def _explode(**_kwargs):
        raise RuntimeError("query boom")

    monkeypatch.setattr(module.TenantLLMService, "query", _explode)

    res = _run(module.list_app())
    assert res["code"] == 500
    assert "query boom" in res["message"]
|
||||
@ -0,0 +1,708 @@
|
||||
#
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import importlib.util
|
||||
import inspect
|
||||
import sys
|
||||
from functools import wraps
|
||||
from pathlib import Path
|
||||
from types import ModuleType, SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class _DummyManager:
|
||||
def route(self, *_args, **_kwargs):
|
||||
def decorator(func):
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class _Field:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
def __eq__(self, other):
|
||||
return (self.name, other)
|
||||
|
||||
|
||||
class _DummyMCPServer:
    """Minimal MCPServer model double.

    The class-level ``id`` / ``tenant_id`` attributes act as queryable
    columns; instances shadow them with concrete values supplied via
    keyword arguments, falling back to the defaults used by the tests.
    """

    id = _Field("id")
    tenant_id = _Field("tenant_id")

    # Attributes exported by to_dict, in a fixed order.
    _FIELDS = ("id", "name", "url", "server_type", "tenant_id", "variables", "headers")

    def __init__(self, **kwargs):
        self.id = kwargs.get("id", "")
        self.name = kwargs.get("name", "")
        self.url = kwargs.get("url", "")
        self.server_type = kwargs.get("server_type", "sse")
        self.tenant_id = kwargs.get("tenant_id", "tenant_1")
        self.variables = kwargs.get("variables", {})
        self.headers = kwargs.get("headers", {})

    def to_dict(self):
        return {field: getattr(self, field) for field in self._FIELDS}
|
||||
|
||||
|
||||
class _DummyMCPServerService:
|
||||
@staticmethod
|
||||
def get_servers(*_args, **_kwargs):
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def get_or_none(*_args, **_kwargs):
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_by_id(*_args, **_kwargs):
|
||||
return False, None
|
||||
|
||||
@staticmethod
|
||||
def get_by_name_and_tenant(*_args, **_kwargs):
|
||||
return False, None
|
||||
|
||||
@staticmethod
|
||||
def insert(**_kwargs):
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def filter_update(*_args, **_kwargs):
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def delete_by_ids(*_args, **_kwargs):
|
||||
return True
|
||||
|
||||
|
||||
class _DummyTenantService:
|
||||
@staticmethod
|
||||
def get_by_id(*_args, **_kwargs):
|
||||
return True, SimpleNamespace(id="tenant_1")
|
||||
|
||||
|
||||
class _DummyTool:
|
||||
def __init__(self, name):
|
||||
self._name = name
|
||||
|
||||
def model_dump(self):
|
||||
return {"name": self._name}
|
||||
|
||||
|
||||
class _DummyMCPToolCallSession:
    """MCPToolCallSession double: two canned tools, every call answers "ok"."""

    def __init__(self, _mcp_server, _variables):
        # Constructor arguments are accepted for signature compatibility
        # but ignored; the tool list is fixed.
        self._tools = [_DummyTool(label) for label in ("tool_a", "tool_b")]

    def get_tools(self, _timeout):
        return self._tools

    def tool_call(self, _name, _arguments, _timeout):
        return "ok"
|
||||
|
||||
|
||||
def _run(coro):
|
||||
return asyncio.run(coro)
|
||||
|
||||
|
||||
def _set_request_json(monkeypatch, module, payload):
|
||||
async def _request_json():
|
||||
return payload
|
||||
|
||||
monkeypatch.setattr(module, "get_request_json", _request_json)
|
||||
|
||||
|
||||
def _load_mcp_server_app(monkeypatch):
    """Import api/apps/mcp_server_app.py in isolation.

    Registers stub modules in ``sys.modules`` for every collaborator the
    route file imports (api.apps, DB models/services, MCP tool-call session,
    api utils), then executes the real module file with a dummy blueprint
    manager injected. Returns the executed module object.
    """
    repo_root = Path(__file__).resolve().parents[4]

    # Stub the `common` package but keep its real path so submodules resolve.
    common_pkg = ModuleType("common")
    common_pkg.__path__ = [str(repo_root / "common")]
    monkeypatch.setitem(sys.modules, "common", common_pkg)

    # Stub api.apps: fixed current user, no-op login decorator.
    apps_mod = ModuleType("api.apps")
    apps_mod.current_user = SimpleNamespace(id="tenant_1")
    apps_mod.login_required = lambda func: func
    monkeypatch.setitem(sys.modules, "api.apps", apps_mod)

    # Stub the DB model module with the dummy MCPServer model.
    db_models_mod = ModuleType("api.db.db_models")
    db_models_mod.MCPServer = _DummyMCPServer
    monkeypatch.setitem(sys.modules, "api.db.db_models", db_models_mod)

    # Stub the MCP server service.
    mcp_service_mod = ModuleType("api.db.services.mcp_server_service")
    mcp_service_mod.MCPServerService = _DummyMCPServerService
    monkeypatch.setitem(sys.modules, "api.db.services.mcp_server_service", mcp_service_mod)

    # Stub the user/tenant service.
    user_service_mod = ModuleType("api.db.services.user_service")
    user_service_mod.TenantService = _DummyTenantService
    monkeypatch.setitem(sys.modules, "api.db.services.user_service", user_service_mod)

    # Stub the MCP tool-call connection helpers.
    mcp_conn_mod = ModuleType("common.mcp_tool_call_conn")
    mcp_conn_mod.MCPToolCallSession = _DummyMCPToolCallSession
    mcp_conn_mod.close_multiple_mcp_toolcall_sessions = lambda _sessions: None
    monkeypatch.setitem(sys.modules, "common.mcp_tool_call_conn", mcp_conn_mod)

    # Stub api.utils.api_utils with plain-dict response envelopes and a
    # validate_request decorator that simply awaits/calls the wrapped view.
    api_utils_mod = ModuleType("api.utils.api_utils")

    async def _default_request_json():
        return {}

    def _get_json_result(code=0, message="success", data=None):
        return {"code": code, "message": message, "data": data}

    def _get_data_error_result(code=102, message="Sorry! Data missing!"):
        return {"code": code, "message": message}

    def _server_error_response(error):
        return {"code": 100, "message": repr(error)}

    async def _get_mcp_tools(*_args, **_kwargs):
        return {}

    def _validate_request(*_args, **_kwargs):
        def _decorator(func):
            @wraps(func)
            async def _wrapped(*func_args, **func_kwargs):
                if inspect.iscoroutinefunction(func):
                    return await func(*func_args, **func_kwargs)
                return func(*func_args, **func_kwargs)

            return _wrapped

        return _decorator

    api_utils_mod.get_request_json = _default_request_json
    api_utils_mod.get_json_result = _get_json_result
    api_utils_mod.get_data_error_result = _get_data_error_result
    api_utils_mod.server_error_response = _server_error_response
    api_utils_mod.validate_request = _validate_request
    api_utils_mod.get_mcp_tools = _get_mcp_tools
    monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod)

    # Execute the real route file under a private module name, registering
    # it in sys.modules (some decorators look the module up by name) and
    # injecting the dummy manager before exec.
    module_name = "test_mcp_server_app_unit_module"
    module_path = repo_root / "api" / "apps" / "mcp_server_app.py"
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    module = importlib.util.module_from_spec(spec)
    module.manager = _DummyManager()
    monkeypatch.setitem(sys.modules, module_name, module)
    spec.loader.exec_module(module)
    return module
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_list_mcp_desc_pagination_and_exception(monkeypatch):
    """list_mcp paginates ascending when desc=false (page 2 of size 1 returns
    the second row) and maps service exceptions to a code-100 envelope."""
    module = _load_mcp_server_app(monkeypatch)

    monkeypatch.setattr(
        module,
        "request",
        SimpleNamespace(args={"keywords": "k", "page": "2", "page_size": "1", "orderby": "create_time", "desc": "false"}),
    )
    _set_request_json(monkeypatch, module, {"mcp_ids": []})
    monkeypatch.setattr(module.MCPServerService, "get_servers", lambda *_args, **_kwargs: [{"id": "a"}, {"id": "b"}])

    res = _run(module.list_mcp())
    assert res["code"] == 0
    assert res["data"]["total"] == 2
    # Page 2, page size 1, ascending order -> only the second server.
    assert res["data"]["mcp_servers"] == [{"id": "b"}]

    # Exception path: the service blowing up becomes a code-100 response.
    monkeypatch.setattr(module, "request", SimpleNamespace(args={}))
    _set_request_json(monkeypatch, module, {"mcp_ids": []})

    def _raise_list(*_args, **_kwargs):
        raise RuntimeError("list explode")

    monkeypatch.setattr(module.MCPServerService, "get_servers", _raise_list)
    res = _run(module.list_mcp())
    assert res["code"] == 100
    assert "list explode" in res["message"]
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_detail_not_found_success_and_exception(monkeypatch):
    """detail: NOT_FOUND when missing, serialized record on success, code 100 on service error."""
    module = _load_mcp_server_app(monkeypatch)
    monkeypatch.setattr(module, "request", SimpleNamespace(args={"mcp_id": "mcp-1"}))

    # No matching server -> RetCode.NOT_FOUND.
    monkeypatch.setattr(module.MCPServerService, "get_or_none", lambda **_kwargs: None)
    res = module.detail()
    assert res["code"] == module.RetCode.NOT_FOUND

    # Server owned by the authenticated tenant -> success with the record payload.
    monkeypatch.setattr(
        module.MCPServerService,
        "get_or_none",
        lambda **_kwargs: _DummyMCPServer(id="mcp-1", name="srv", url="http://a", server_type="sse", tenant_id="tenant_1"),
    )
    res = module.detail()
    assert res["code"] == 0
    assert res["data"]["id"] == "mcp-1"

    # Unexpected service exception -> generic code-100 error payload.
    def _raise_detail(**_kwargs):
        raise RuntimeError("detail explode")

    monkeypatch.setattr(module.MCPServerService, "get_or_none", _raise_detail)
    res = module.detail()
    assert res["code"] == 100
    assert "detail explode" in res["message"]
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_create_validation_guards(monkeypatch):
    """create: each malformed payload is rejected with its matching validation message."""
    module = _load_mcp_server_app(monkeypatch)

    def _submit(payload):
        # Arm the request body and call the undecorated create() handler.
        _set_request_json(monkeypatch, module, payload)
        return _run(module.create.__wrapped__())

    def _no_duplicate(**_kwargs):
        return False, None

    def _duplicate(**_kwargs):
        return True, object()

    # Unsupported server_type is rejected.
    monkeypatch.setattr(module.MCPServerService, "get_by_name_and_tenant", _no_duplicate)
    outcome = _submit({"name": "srv", "url": "http://a", "server_type": "invalid"})
    assert "Unsupported MCP server type" in outcome["message"]

    # Empty name is rejected.
    outcome = _submit({"name": "", "url": "http://a", "server_type": "sse"})
    assert "Invalid MCP name" in outcome["message"]

    # Name colliding with an existing server of the tenant is rejected.
    monkeypatch.setattr(module.MCPServerService, "get_by_name_and_tenant", _duplicate)
    outcome = _submit({"name": "srv", "url": "http://a", "server_type": "sse"})
    assert "Duplicated MCP server name" in outcome["message"]

    # Empty url is rejected.
    monkeypatch.setattr(module.MCPServerService, "get_by_name_and_tenant", _no_duplicate)
    outcome = _submit({"name": "srv", "url": "", "server_type": "sse"})
    assert "Invalid url" in outcome["message"]
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_create_service_paths(monkeypatch):
    """create: tenant-missing, tool-discovery failure, insert failure, success, and raised-exception paths."""
    module = _load_mcp_server_app(monkeypatch)

    # Payload that clears every validation guard; headers/variables arrive as JSON strings.
    base_payload = {
        "name": "srv",
        "url": "http://server",
        "server_type": "sse",
        "headers": '{"Authorization": "x"}',
        "variables": '{"tools": {"old": 1}, "token": "abc"}',
        "timeout": "2.5",
    }

    monkeypatch.setattr(module, "get_uuid", lambda: "uuid-create")
    monkeypatch.setattr(module.MCPServerService, "get_by_name_and_tenant", lambda **_kwargs: (False, None))

    # Tenant lookup fails -> data-error response naming the tenant.
    _set_request_json(monkeypatch, module, dict(base_payload))
    monkeypatch.setattr(module.TenantService, "get_by_id", lambda *_args, **_kwargs: (False, None))
    res = _run(module.create.__wrapped__())
    assert "Tenant not found" in res["message"]

    _set_request_json(monkeypatch, module, dict(base_payload))
    monkeypatch.setattr(module.TenantService, "get_by_id", lambda *_args, **_kwargs: (True, object()))

    # Tool discovery returns an error string -> it is propagated as the response code.
    async def _thread_pool_tools_error(_func, _servers, _timeout):
        return None, "tools error"

    monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_tools_error)
    res = _run(module.create.__wrapped__())
    assert res["code"] == "tools error"
    assert "Sorry! Data missing!" in res["message"]

    _set_request_json(monkeypatch, module, dict(base_payload))

    # Discovery succeeds with one valid tool and one malformed entry (no "name").
    async def _thread_pool_ok(_func, servers, _timeout):
        return {servers[0].name: [{"name": "tool_a"}, {"invalid": True}]}, None

    monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_ok)

    # Insert failure -> the failure string becomes the response code.
    monkeypatch.setattr(module.MCPServerService, "insert", lambda **_kwargs: False)
    res = _run(module.create.__wrapped__())
    assert res["code"] == "Failed to create MCP server."
    assert "Sorry! Data missing!" in res["message"]

    # Happy path: record created with generated id, tenant binding, and only the
    # named tool retained in variables["tools"].
    _set_request_json(monkeypatch, module, dict(base_payload))
    monkeypatch.setattr(module.MCPServerService, "insert", lambda **_kwargs: True)
    res = _run(module.create.__wrapped__())
    assert res["code"] == 0
    assert res["data"]["id"] == "uuid-create"
    assert res["data"]["tenant_id"] == "tenant_1"
    assert res["data"]["variables"]["tools"] == {"tool_a": {"name": "tool_a"}}

    # thread_pool_exec raising -> generic code-100 error payload.
    _set_request_json(monkeypatch, module, dict(base_payload))

    async def _thread_pool_raises(_func, _servers, _timeout):
        raise RuntimeError("create explode")

    monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_raises)
    res = _run(module.create.__wrapped__())
    assert res["code"] == 100
    assert "create explode" in res["message"]
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_update_validation_guards(monkeypatch):
    """update: ownership and field validation failures each return their matching error message."""
    module = _load_mcp_server_app(monkeypatch)

    # Record currently stored for mcp-1, owned by the authenticated tenant.
    existing = _DummyMCPServer(id="mcp-1", name="srv", url="http://server", server_type="sse", tenant_id="tenant_1", variables={}, headers={})

    # Unknown id -> "Cannot find MCP server".
    _set_request_json(monkeypatch, module, {"mcp_id": "mcp-1"})
    monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (False, None))
    res = _run(module.update.__wrapped__())
    assert "Cannot find MCP server" in res["message"]

    # A server owned by another tenant is treated the same as not found.
    _set_request_json(monkeypatch, module, {"mcp_id": "mcp-1"})
    monkeypatch.setattr(
        module.MCPServerService,
        "get_by_id",
        lambda _mcp_id: (True, _DummyMCPServer(id="mcp-1", name="srv", url="http://server", server_type="sse", tenant_id="other", variables={}, headers={})),
    )
    res = _run(module.update.__wrapped__())
    assert "Cannot find MCP server" in res["message"]

    # Unsupported server_type is rejected.
    _set_request_json(monkeypatch, module, {"mcp_id": "mcp-1", "server_type": "invalid"})
    monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (True, existing))
    res = _run(module.update.__wrapped__())
    assert "Unsupported MCP server type" in res["message"]

    # An over-long name (256 chars) is rejected.
    _set_request_json(monkeypatch, module, {"mcp_id": "mcp-1", "name": "a" * 256})
    res = _run(module.update.__wrapped__())
    assert "Invalid MCP name" in res["message"]

    # Empty url is rejected.
    _set_request_json(monkeypatch, module, {"mcp_id": "mcp-1", "url": ""})
    res = _run(module.update.__wrapped__())
    assert "Invalid url" in res["message"]
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_update_service_paths(monkeypatch):
    """update: tool-listing failure, filter_update failure, re-fetch failure, success, and exception paths."""
    module = _load_mcp_server_app(monkeypatch)

    # Record currently stored for mcp-1, owned by the authenticated tenant.
    existing = _DummyMCPServer(
        id="mcp-1",
        name="srv",
        url="http://server",
        server_type="sse",
        tenant_id="tenant_1",
        variables={"tools": {"old": {"enabled": True}}, "token": "abc"},
        headers={"Authorization": "old"},
    )
    # Record returned by the post-update re-fetch on the success path.
    updated = _DummyMCPServer(
        id="mcp-1",
        name="srv-new",
        url="http://server-new",
        server_type="sse",
        tenant_id="tenant_1",
        variables={"tools": {"tool_a": {"name": "tool_a"}}},
        headers={"Authorization": "new"},
    )

    # Fully valid update payload; headers/variables arrive as JSON strings.
    base_payload = {
        "mcp_id": "mcp-1",
        "name": "srv-new",
        "url": "http://server-new",
        "server_type": "sse",
        "headers": '{"Authorization": "new"}',
        "variables": '{"tools": {"ignore": 1}, "token": "new"}',
        "timeout": "3.0",
    }

    # Tool discovery fails -> the error string becomes the response code.
    _set_request_json(monkeypatch, module, dict(base_payload))
    monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (True, existing))

    async def _thread_pool_tools_error(_func, _servers, _timeout):
        return None, "update tools error"

    monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_tools_error)
    res = _run(module.update.__wrapped__())
    assert res["code"] == "update tools error"
    assert "Sorry! Data missing!" in res["message"]

    _set_request_json(monkeypatch, module, dict(base_payload))

    # Discovery succeeds with one valid tool and one malformed entry.
    async def _thread_pool_ok(_func, servers, _timeout):
        return {servers[0].name: [{"name": "tool_a"}, {"bad": True}]}, None

    monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_ok)
    # filter_update returning False -> update-failure message.
    monkeypatch.setattr(module.MCPServerService, "filter_update", lambda *_args, **_kwargs: False)
    res = _run(module.update.__wrapped__())
    assert "Failed to updated MCP server" in res["message"]

    _set_request_json(monkeypatch, module, dict(base_payload))
    monkeypatch.setattr(module.MCPServerService, "filter_update", lambda *_args, **_kwargs: True)

    # First get_by_id (pre-update) succeeds; second (post-update re-fetch) fails.
    def _get_by_id_fetch_fail(_mcp_id):
        if _get_by_id_fetch_fail.calls == 0:
            _get_by_id_fetch_fail.calls += 1
            return True, existing
        return False, None

    _get_by_id_fetch_fail.calls = 0
    monkeypatch.setattr(module.MCPServerService, "get_by_id", _get_by_id_fetch_fail)
    res = _run(module.update.__wrapped__())
    assert "Failed to fetch updated MCP server" in res["message"]

    _set_request_json(monkeypatch, module, dict(base_payload))

    # Both lookups succeed -> handler returns the refreshed record.
    def _get_by_id_success(_mcp_id):
        if _get_by_id_success.calls == 0:
            _get_by_id_success.calls += 1
            return True, existing
        return True, updated

    _get_by_id_success.calls = 0
    monkeypatch.setattr(module.MCPServerService, "get_by_id", _get_by_id_success)
    res = _run(module.update.__wrapped__())
    assert res["code"] == 0
    assert res["data"]["id"] == "mcp-1"

    # thread_pool_exec raising -> generic code-100 error payload.
    _set_request_json(monkeypatch, module, dict(base_payload))
    monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (True, existing))

    async def _thread_pool_raises(_func, _servers, _timeout):
        raise RuntimeError("update explode")

    monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_raises)
    res = _run(module.update.__wrapped__())
    assert res["code"] == 100
    assert "update explode" in res["message"]
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_rm_failure_success_and_exception(monkeypatch):
    """rm: delete failure, successful delete, and a service exception each map to the right payload."""
    module = _load_mcp_server_app(monkeypatch)

    def _invoke_rm(delete_stub):
        # Re-arm the request body, swap in the delete stub, and call the raw handler.
        _set_request_json(monkeypatch, module, {"mcp_ids": ["a", "b"]})
        monkeypatch.setattr(module.MCPServerService, "delete_by_ids", delete_stub)
        return _run(module.rm.__wrapped__())

    # Service reports deletion failure.
    outcome = _invoke_rm(lambda _ids: False)
    assert "Failed to delete MCP servers" in outcome["message"]

    # Successful deletion returns data=True.
    outcome = _invoke_rm(lambda _ids: True)
    assert outcome["code"] == 0
    assert outcome["data"] is True

    # Service exception -> generic code-100 error payload.
    def _explode(_ids):
        raise RuntimeError("rm explode")

    outcome = _invoke_rm(_explode)
    assert outcome["code"] == 100
    assert "rm explode" in outcome["message"]
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_import_multiple_missing_servers_and_exception(monkeypatch):
    """import_multiple: an empty server map is rejected; a lookup exception becomes code 100."""
    module = _load_mcp_server_app(monkeypatch)

    def _post_import(payload):
        # Arm the request body and call the undecorated import handler.
        _set_request_json(monkeypatch, module, payload)
        return _run(module.import_multiple.__wrapped__())

    # An empty mcpServers mapping is refused outright.
    outcome = _post_import({"mcpServers": {}})
    assert "No MCP servers provided" in outcome["message"]

    # A RuntimeError during the duplicate-name lookup surfaces as a code-100 error.
    def _lookup_explodes(**_kwargs):
        raise RuntimeError("import explode")

    monkeypatch.setattr(module.MCPServerService, "get_by_name_and_tenant", _lookup_explodes)
    outcome = _post_import({"mcpServers": {"srv": {"type": "sse", "url": "http://x"}}, "timeout": "1"})
    assert outcome["code"] == 100
    assert "import explode" in outcome["message"]
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_import_multiple_mixed_results(monkeypatch):
    """import_multiple: one request mixing per-server failures, a rename, and a success."""
    module = _load_mcp_server_app(monkeypatch)

    payload = {
        "mcpServers": {
            # No url -> "Missing required fields".
            "missing_fields": {"type": "sse"},
            # Empty server name -> "Invalid MCP name".
            "": {"type": "sse", "url": "http://empty"},
            # Name collides once -> imported under a suffixed name.
            "dup": {"type": "sse", "url": "http://dup", "authorization_token": "dup-token"},
            # Tool discovery reports an error for this server.
            "tool_err": {"type": "sse", "url": "http://err"},
            # Persisting returns False for this server.
            "insert_fail": {"type": "sse", "url": "http://fail"},
        },
        "timeout": "3",
    }
    _set_request_json(monkeypatch, module, payload)

    monkeypatch.setattr(module, "get_uuid", lambda: "uuid-import")

    # Report a duplicate only on the first lookup of "dup" so the handler
    # retries with a generated alternative name.
    def _get_by_name_and_tenant(name, tenant_id):
        if name == "dup" and not _get_by_name_and_tenant.first_dup_seen:
            _get_by_name_and_tenant.first_dup_seen = True
            return True, object()
        return False, None

    _get_by_name_and_tenant.first_dup_seen = False
    monkeypatch.setattr(module.MCPServerService, "get_by_name_and_tenant", _get_by_name_and_tenant)

    # Tool discovery: fails for "tool_err", otherwise one valid and one malformed entry.
    async def _thread_pool_exec(func, servers, _timeout):
        mcp_server = servers[0]
        if mcp_server.name == "tool_err":
            return None, "tool call failed"
        return {mcp_server.name: [{"name": "tool_a"}, {"invalid": True}]}, None

    monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_exec)

    # Persisting succeeds for every server except "insert_fail".
    def _insert(**kwargs):
        return kwargs["name"] != "insert_fail"

    monkeypatch.setattr(module.MCPServerService, "insert", _insert)

    res = _run(module.import_multiple.__wrapped__())
    # The request as a whole succeeds; outcomes are reported per server.
    assert res["code"] == 0

    results = {item["server"]: item for item in res["data"]["results"]}
    assert results["missing_fields"]["success"] is False
    assert "Missing required fields" in results["missing_fields"]["message"]
    assert results[""]["success"] is False
    assert "Invalid MCP name" in results[""]["message"]
    assert results["tool_err"]["success"] is False
    assert "tool call failed" in results["tool_err"]["message"]
    assert results["insert_fail"]["success"] is False
    assert "Failed to create MCP server" in results["insert_fail"]["message"]
    # The duplicate is imported under "dup_0" with an explanatory message.
    assert results["dup"]["success"] is True
    assert results["dup"]["new_name"] == "dup_0"
    assert "Renamed from 'dup' to 'dup_0' avoid duplication" == results["dup"]["message"]
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_export_multiple_missing_ids_success_and_exception(monkeypatch):
    """export_multiple: empty id list, tenant filtering on success, and a service error."""
    module = _load_mcp_server_app(monkeypatch)

    # No ids -> rejected up front.
    _set_request_json(monkeypatch, module, {"mcp_ids": []})
    res = _run(module.export_multiple.__wrapped__())
    assert "No MCP server IDs provided" in res["message"]

    _set_request_json(monkeypatch, module, {"mcp_ids": ["id1", "id2", "id3"]})

    # id1: owned by the current tenant; id2: another tenant's server; id3: missing.
    def _get_by_id(mcp_id):
        if mcp_id == "id1":
            return True, _DummyMCPServer(
                id="id1",
                name="srv-one",
                url="http://one",
                server_type="sse",
                tenant_id="tenant_1",
                variables={"authorization_token": "tok", "tools": {"tool_a": {"enabled": True}}},
            )
        if mcp_id == "id2":
            return True, _DummyMCPServer(
                id="id2",
                name="srv-two",
                url="http://two",
                server_type="sse",
                tenant_id="other",
                variables={},
            )
        return False, None

    monkeypatch.setattr(module.MCPServerService, "get_by_id", _get_by_id)
    res = _run(module.export_multiple.__wrapped__())
    assert res["code"] == 0
    # Only the caller's own server is exported; foreign and missing ids are skipped.
    assert list(res["data"]["mcpServers"].keys()) == ["srv-one"]

    # get_by_id raising -> generic code-100 error payload.
    _set_request_json(monkeypatch, module, {"mcp_ids": ["id1"]})

    def _raise_export(_mcp_id):
        raise RuntimeError("export explode")

    monkeypatch.setattr(module.MCPServerService, "get_by_id", _raise_export)
    res = _run(module.export_multiple.__wrapped__())
    assert res["code"] == 100
    assert "export explode" in res["message"]
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_list_tools_missing_ids_success_inner_error_outer_error_and_finally_cleanup(monkeypatch):
    """list_tools: missing ids, success, per-server error, outer error — sessions are always closed."""
    module = _load_mcp_server_app(monkeypatch)

    # No ids -> rejected before any session is opened.
    _set_request_json(monkeypatch, module, {"mcp_ids": []})
    res = _run(module.list_tools.__wrapped__())
    assert "No MCP server IDs provided" in res["message"]

    # Server with one cached tool whose enabled flag is False.
    server = _DummyMCPServer(
        id="id1",
        name="srv-tools",
        url="http://tools",
        server_type="sse",
        tenant_id="tenant_1",
        variables={"tools": {"tool_a": {"enabled": False}}},
    )

    _set_request_json(monkeypatch, module, {"mcp_ids": ["id1"], "timeout": "2.0"})
    monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (True, server))

    close_calls = []

    # Delegate real work, but intercept and record session-close calls.
    async def _thread_pool_exec_success(func, *args):
        if func is module.close_multiple_mcp_toolcall_sessions:
            close_calls.append(args[0])
            return None
        return func(*args)

    monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_exec_success)
    res = _run(module.list_tools.__wrapped__())
    assert res["code"] == 0
    # tool_a keeps its cached enabled=False; the second reported tool is enabled=True.
    assert res["data"]["id1"][0]["name"] == "tool_a"
    assert res["data"]["id1"][0]["enabled"] is False
    assert res["data"]["id1"][1]["enabled"] is True
    # Exactly one session was passed to the cleanup call.
    assert close_calls and len(close_calls[-1]) == 1

    # A per-server failure yields code 102, and the opened session is still closed.
    _set_request_json(monkeypatch, module, {"mcp_ids": ["id1"], "timeout": "2.0"})
    close_calls_inner = []

    async def _thread_pool_exec_inner_error(func, *args):
        if func is module.close_multiple_mcp_toolcall_sessions:
            close_calls_inner.append(args[0])
            return None
        raise RuntimeError("inner tools explode")

    monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_exec_inner_error)
    res = _run(module.list_tools.__wrapped__())
    assert res["code"] == 102
    assert "MCP list tools error" in res["message"]
    assert close_calls_inner and len(close_calls_inner[-1]) == 1

    # An exception outside the per-server loop -> code 100; cleanup still runs.
    _set_request_json(monkeypatch, module, {"mcp_ids": ["id1"], "timeout": "2.0"})
    close_calls_outer = []

    def _raise_get_by_id(_mcp_id):
        raise RuntimeError("outer explode")

    monkeypatch.setattr(module.MCPServerService, "get_by_id", _raise_get_by_id)

    async def _thread_pool_exec_outer(func, *args):
        if func is module.close_multiple_mcp_toolcall_sessions:
            close_calls_outer.append(args[0])
            return None
        return func(*args)

    monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_exec_outer)
    res = _run(module.list_tools.__wrapped__())
    assert res["code"] == 100
    assert "outer explode" in res["message"]
    assert close_calls_outer
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_test_tool_missing_mcp_id(monkeypatch):
    """test_tool: an empty mcp_id is rejected before any tool call is attempted."""
    module = _load_mcp_server_app(monkeypatch)

    _set_request_json(monkeypatch, module, {"mcp_id": "", "tool_name": "tool_a", "arguments": {"x": 1}})
    res = _run(module.test_tool.__wrapped__())
    assert "No MCP server ID provided" in res["message"]
|
||||
@ -20,8 +20,6 @@ import pytest
|
||||
from test_web_api.common import create_memory
|
||||
from configs import INVALID_API_TOKEN
|
||||
from libs.auth import RAGFlowWebApiAuth
|
||||
from hypothesis import example, given, settings
|
||||
from utils.hypothesis_utils import valid_names
|
||||
|
||||
|
||||
class TestAuthorization:
|
||||
@ -42,9 +40,7 @@ class TestAuthorization:
|
||||
|
||||
class TestMemoryCreate:
|
||||
@pytest.mark.p1
|
||||
@given(name=valid_names())
|
||||
@example("d" * 128)
|
||||
@settings(max_examples=20)
|
||||
@pytest.mark.parametrize("name", ["test_memory_name", "d" * 128])
|
||||
def test_name(self, WebApiAuth, name):
|
||||
payload = {
|
||||
"name": name,
|
||||
@ -79,7 +75,7 @@ class TestMemoryCreate:
|
||||
assert res["message"] == expected_message, res
|
||||
|
||||
@pytest.mark.p2
|
||||
@given(name=valid_names())
|
||||
@pytest.mark.parametrize("name", ["invalid_type_name", "memory_alpha"])
|
||||
def test_type_invalid(self, WebApiAuth, name):
|
||||
payload = {
|
||||
"name": name,
|
||||
|
||||
@ -13,14 +13,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import re
|
||||
|
||||
import pytest
|
||||
from test_web_api.common import update_memory
|
||||
from configs import INVALID_API_TOKEN
|
||||
from libs.auth import RAGFlowWebApiAuth
|
||||
from hypothesis import HealthCheck, example, given, settings
|
||||
from utils import encode_avatar
|
||||
from utils.file_utils import create_image_file
|
||||
from utils.hypothesis_utils import valid_names
|
||||
|
||||
|
||||
class TestAuthorization:
|
||||
@ -42,15 +42,14 @@ class TestAuthorization:
|
||||
class TestMemoryUpdate:
|
||||
|
||||
@pytest.mark.p1
|
||||
@given(name=valid_names())
|
||||
@example("f" * 128)
|
||||
@settings(max_examples=20, suppress_health_check=[HealthCheck.function_scoped_fixture])
|
||||
@pytest.mark.parametrize("name", ["updated_memory", "f" * 128])
|
||||
def test_name(self, WebApiAuth, add_memory_func, name):
|
||||
memory_ids = add_memory_func
|
||||
payload = {"name": name}
|
||||
res = update_memory(WebApiAuth, memory_ids[0], payload)
|
||||
assert res["code"] == 0, res
|
||||
assert res["data"]["name"] == name, res
|
||||
pattern = rf"^{re.escape(name)}(?:\(\d+\))?$"
|
||||
assert re.match(pattern, res["data"]["name"]), res
|
||||
|
||||
@pytest.mark.p2
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
@ -71,6 +71,20 @@ Are you asking about the fruit itself, or its use in a specific context?
|
||||
assert message["agent_id"] == agent_id, message
|
||||
assert message["session_id"] == session_id, message
|
||||
|
||||
    @pytest.mark.p2
    def test_add_message_invalid_memory_id(self, WebApiAuth):
        """Adding a message to a non-existent memory reports a partial-failure error."""
        message_payload = {
            "memory_id": ["missing_memory_id"],
            "agent_id": uuid.uuid4().hex,
            "session_id": uuid.uuid4().hex,
            "user_id": "",
            "user_input": "what is pineapple?",
            "agent_response": "pineapple response",
        }
        res = add_message(WebApiAuth, message_payload)
        # The API answers 500 and names the failed additions in the message.
        assert res["code"] == 500, res
        assert "Some messages failed to add" in res["message"], res
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("add_empty_multiple_type_memory")
|
||||
class TestAddMultipleTypeMessage:
|
||||
|
||||
@ -15,8 +15,9 @@
|
||||
#
|
||||
import random
|
||||
import pytest
|
||||
import requests
|
||||
from test_web_api.common import forget_message, list_memory_message, get_message_content
|
||||
from configs import INVALID_API_TOKEN
|
||||
from configs import HOST_ADDRESS, INVALID_API_TOKEN, VERSION
|
||||
from libs.auth import RAGFlowWebApiAuth
|
||||
|
||||
|
||||
@ -52,3 +53,17 @@ class TestForgetMessage:
|
||||
forgot_message_res = get_message_content(WebApiAuth, memory_id, message["message_id"])
|
||||
assert forgot_message_res["code"] == 0, forgot_message_res
|
||||
assert forgot_message_res["data"]["forget_at"] not in ["-", ""], forgot_message_res
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_forget_message_invalid_memory_id(self, WebApiAuth):
|
||||
res = forget_message(WebApiAuth, "missing_memory_id", 1)
|
||||
assert res["code"] == 404, res
|
||||
assert "not found" in res["message"].lower(), res
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_forget_message_invalid_message_id(self, WebApiAuth):
|
||||
memory_id = self.memory_id
|
||||
url = f"{HOST_ADDRESS}/api/{VERSION}/messages/{memory_id}:invalid_message_id"
|
||||
res = requests.delete(url=url, headers={"Content-Type": "application/json"}, auth=WebApiAuth).json()
|
||||
assert res["code"] == 500, res
|
||||
assert "Internal server error" in res["message"], res
|
||||
|
||||
@ -49,3 +49,16 @@ class TestGetMessageContent:
|
||||
for field in ["content", "content_embed"]:
|
||||
assert field in content_res["data"]
|
||||
assert content_res["data"][field] is not None, content_res
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_get_message_content_invalid_memory_id(self, WebApiAuth):
|
||||
res = get_message_content(WebApiAuth, "missing_memory_id", 1)
|
||||
assert res["code"] == 404, res
|
||||
assert "not found" in res["message"].lower(), res
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_get_message_content_invalid_message_id(self, WebApiAuth):
|
||||
memory_id = self.memory_id
|
||||
res = get_message_content(WebApiAuth, memory_id, 999999999)
|
||||
assert res["code"] == 404, res
|
||||
assert "not found" in res["message"].lower(), res
|
||||
|
||||
@ -66,3 +66,15 @@ class TestGetRecentMessage:
|
||||
for message in res["data"]:
|
||||
assert message["session_id"] == session_id, message
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_get_recent_messages_missing_memory_id(self, WebApiAuth):
|
||||
res = get_recent_message(WebApiAuth, params={})
|
||||
assert res["code"] == 101, res
|
||||
assert "memory_ids is required" in res["message"], res
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_get_recent_messages_csv_memory_ids(self, WebApiAuth):
|
||||
memory_id = self.memory_id
|
||||
res = get_recent_message(WebApiAuth, params={"memory_id": f"{memory_id},{memory_id}"})
|
||||
assert res["code"] == 0, res
|
||||
assert isinstance(res["data"], list), res
|
||||
|
||||
@ -0,0 +1,151 @@
|
||||
#
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import importlib.util
|
||||
import inspect
|
||||
import sys
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from types import ModuleType, SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class _DummyManager:
|
||||
def route(self, *_args, **_kwargs):
|
||||
def decorator(func):
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class _AwaitableValue:
|
||||
def __init__(self, value):
|
||||
self._value = value
|
||||
|
||||
def __await__(self):
|
||||
async def _co():
|
||||
return self._value
|
||||
|
||||
return _co().__await__()
|
||||
|
||||
|
||||
class _DummyArgs(dict):
|
||||
def getlist(self, key):
|
||||
value = self.get(key)
|
||||
if value is None:
|
||||
return []
|
||||
if isinstance(value, list):
|
||||
return value
|
||||
return [value]
|
||||
|
||||
|
||||
class _DummyMemoryApiService:
|
||||
async def add_message(self, *_args, **_kwargs):
|
||||
return True, "ok"
|
||||
|
||||
async def get_messages(self, *_args, **_kwargs):
|
||||
return []
|
||||
|
||||
|
||||
def _run(coro):
|
||||
return asyncio.run(coro)
|
||||
|
||||
|
||||
def _load_memory_routes_module(monkeypatch):
    """Load api/apps/restful_apis/memory_api.py in isolation with its package
    dependencies replaced by lightweight stubs, and return the loaded module."""
    repo_root = Path(__file__).resolve().parents[4]

    # Point the "common" package at the real repository directory.
    common_pkg = ModuleType("common")
    common_pkg.__path__ = [str(repo_root / "common")]
    monkeypatch.setitem(sys.modules, "common", common_pkg)

    # Stub "api.apps": a fixed current_user and a pass-through login_required.
    apps_mod = ModuleType("api.apps")
    apps_mod.__path__ = [str(repo_root / "api" / "apps")]
    apps_mod.current_user = SimpleNamespace(id="user-1")
    apps_mod.login_required = lambda func: func
    monkeypatch.setitem(sys.modules, "api.apps", apps_mod)

    # Stub services package exposing the dummy memory API service.
    services_mod = ModuleType("api.apps.services")
    services_mod.memory_api_service = _DummyMemoryApiService()
    monkeypatch.setitem(sys.modules, "api.apps.services", services_mod)

    # Execute the target file under a private module name so the real app
    # package is never imported; a no-op route manager is injected first so
    # route registration at import time succeeds.
    module_name = "test_message_routes_unit_module"
    module_path = repo_root / "api" / "apps" / "restful_apis" / "memory_api.py"
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    module = importlib.util.module_from_spec(spec)
    module.manager = _DummyManager()
    monkeypatch.setitem(sys.modules, module_name, module)
    spec.loader.exec_module(module)
    return module
|
||||
|
||||
|
||||
def _set_request_json(monkeypatch, module, payload):
    """Patch *module*'s get_request_json so each call awaits a fresh copy of *payload*."""

    def _fake_request_json():
        # Deep-copy per call so the handler cannot mutate the caller's payload.
        return _AwaitableValue(deepcopy(payload))

    monkeypatch.setattr(module, "get_request_json", _fake_request_json)
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_add_message_partial_failure_branch(monkeypatch):
    """add_message: a failed service enqueue yields SERVER_ERROR with a partial-failure message."""
    module = _load_memory_routes_module(monkeypatch)

    _set_request_json(
        monkeypatch,
        module,
        {
            "memory_id": ["memory-1"],
            "agent_id": "agent-1",
            "session_id": "session-1",
            "user_input": "hello",
            "agent_response": "world",
        },
    )

    # The service reports failure for the message batch.
    async def _add_message(_memory_ids, _message_dict):
        return False, "cannot enqueue"

    monkeypatch.setattr(module.memory_api_service, "add_message", _add_message)

    # inspect.unwrap strips the route/login decorators so the raw handler runs.
    res = _run(inspect.unwrap(module.add_message)())
    assert res["code"] == module.RetCode.SERVER_ERROR, res
    assert "Some messages failed to add" in res["message"], res
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_get_messages_csv_and_missing_memory_ids(monkeypatch):
    """get_messages: missing memory ids is an argument error; CSV ids are split and forwarded."""
    module = _load_memory_routes_module(monkeypatch)

    # No memory_id query parameter -> ARGUMENT_ERROR.
    monkeypatch.setattr(module, "request", SimpleNamespace(args=_DummyArgs({})))
    res = _run(inspect.unwrap(module.get_messages)())
    assert res["code"] == module.RetCode.ARGUMENT_ERROR, res
    assert "memory_ids is required." in res["message"], res

    # Comma-separated memory_id plus agent/session filters and a string limit.
    monkeypatch.setattr(
        module,
        "request",
        SimpleNamespace(args=_DummyArgs({"memory_id": "m1,m2", "agent_id": "a1", "session_id": "s1", "limit": "5"})),
    )

    # Assert inside the stub that the handler split the CSV and coerced limit to int.
    async def _get_messages(memory_ids, agent_id, session_id, limit):
        assert memory_ids == ["m1", "m2"]
        assert agent_id == "a1"
        assert session_id == "s1"
        assert limit == 5
        return [{"message_id": 1}]

    monkeypatch.setattr(module.memory_api_service, "get_messages", _get_messages)
    res = _run(inspect.unwrap(module.get_messages)())
    assert res["code"] == module.RetCode.SUCCESS, res
    assert isinstance(res["data"], list), res
|
||||
@ -80,3 +80,23 @@ class TestSearchMessage:
|
||||
assert res["code"] == 0, res
|
||||
assert len(res["data"]) > 0
|
||||
assert len(res["data"]) <= params["top_n"]
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_query_missing_query(self, WebApiAuth):
|
||||
memory_id = self.memory_id
|
||||
res = search_message(WebApiAuth, {"memory_id": memory_id})
|
||||
assert res["code"] in [100, 500], res
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_query_missing_memory_id(self, WebApiAuth):
|
||||
res = search_message(WebApiAuth, {"query": "what is coriander"})
|
||||
assert res["code"] == 0, res
|
||||
assert isinstance(res["data"], list), res
|
||||
|
||||
@pytest.mark.p2
def test_query_with_csv_memory_ids(self, WebApiAuth):
    """Duplicate comma-separated memory ids are accepted by the search API."""
    mem_id = self.memory_id
    payload = {"memory_id": f"{mem_id},{mem_id}", "query": "Coriander is a versatile herb."}
    res = search_message(WebApiAuth, payload)
    assert res["code"] == 0, res
    assert isinstance(res["data"], list), res
|
||||
|
||||
@@ -16,9 +16,11 @@
|
||||
import random
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from test_web_api.common import update_message_status, list_memory_message, get_message_content
|
||||
from configs import INVALID_API_TOKEN
|
||||
from libs.auth import RAGFlowWebApiAuth
|
||||
from configs import HOST_ADDRESS, VERSION
|
||||
|
||||
|
||||
class TestAuthorization:
|
||||
@@ -73,3 +75,34 @@ class TestUpdateMessageStatus:
|
||||
res = get_message_content(WebApiAuth, memory_id, message["message_id"])
|
||||
assert res["code"] == 0, res
|
||||
assert res["data"]["status"], res
|
||||
|
||||
@pytest.mark.p2
def test_update_invalid_status_type(self, WebApiAuth):
    """Sending status as a string instead of a boolean yields code 101."""
    mem_id = self.memory_id
    listing = list_memory_message(WebApiAuth, mem_id)
    assert listing["code"] == 0, listing
    first_message = listing["data"]["messages"]["message_list"][0]
    msg_id = first_message["message_id"]

    endpoint = f"{HOST_ADDRESS}/api/{VERSION}/messages/{mem_id}:{msg_id}"
    response = requests.put(
        url=endpoint,
        headers={"Content-Type": "application/json"},
        auth=WebApiAuth,
        json={"status": "false"},
    ).json()
    assert response["code"] == 101, response
    assert "Status must be a boolean." in response["message"], response
|
||||
|
||||
@pytest.mark.p2
def test_update_invalid_memory_id(self, WebApiAuth):
    """An unknown memory id produces a 404-style response."""
    result = update_message_status(WebApiAuth, "missing_memory_id", 1, False)
    assert result["code"] == 404, result
    assert "not found" in result["message"].lower(), result
|
||||
|
||||
@pytest.mark.p2
def test_update_invalid_message_id(self, WebApiAuth):
    """A non-numeric message id surfaces as an internal server error."""
    endpoint = f"{HOST_ADDRESS}/api/{VERSION}/messages/{self.memory_id}:invalid_message_id"
    response = requests.put(
        url=endpoint,
        headers={"Content-Type": "application/json"},
        auth=WebApiAuth,
        json={"status": True},
    ).json()
    assert response["code"] == 500, response
    assert "Internal server error" in response["message"], response
|
||||
|
||||
@@ -13,6 +13,11 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import importlib.util
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
|
||||
import pytest
|
||||
from common import plugin_llm_tools
|
||||
from configs import INVALID_API_TOKEN
|
||||
@@ -40,3 +45,54 @@ class TestPluginTools:
|
||||
res = plugin_llm_tools(WebApiAuth)
|
||||
assert res["code"] == 0, res
|
||||
assert isinstance(res["data"], list), res
|
||||
|
||||
|
||||
class _DummyManager:
|
||||
def route(self, *_args, **_kwargs):
|
||||
def decorator(func):
|
||||
return func
|
||||
return decorator
|
||||
|
||||
|
||||
def _load_plugin_app(monkeypatch):
    """Load api/apps/plugin_app.py in isolation with its dependencies stubbed.

    Registers lightweight stand-ins for the ``common`` package, ``api.apps``
    (no-op ``login_required``) and ``agent.plugin`` in ``sys.modules`` via
    *monkeypatch* (so the stubs are undone after the test), then executes the
    module with a ``_DummyManager`` so route registration is a no-op.
    Returns the loaded module object.
    """
    repo_root = Path(__file__).resolve().parents[4]
    common_pkg = ModuleType("common")
    common_pkg.__path__ = [str(repo_root / "common")]
    monkeypatch.setitem(sys.modules, "common", common_pkg)

    stub_apps = ModuleType("api.apps")
    stub_apps.login_required = lambda func: func
    monkeypatch.setitem(sys.modules, "api.apps", stub_apps)

    stub_plugin = ModuleType("agent.plugin")

    class _StubGlobalPluginManager:
        @staticmethod
        def get_llm_tools():
            return []

    stub_plugin.GlobalPluginManager = _StubGlobalPluginManager
    monkeypatch.setitem(sys.modules, "agent.plugin", stub_plugin)

    # Reuse repo_root instead of re-resolving __file__ a second time.
    module_path = repo_root / "api" / "apps" / "plugin_app.py"
    spec = importlib.util.spec_from_file_location("test_plugin_app_unit", module_path)
    module = importlib.util.module_from_spec(spec)
    module.manager = _DummyManager()
    spec.loader.exec_module(module)
    return module
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_llm_tools_metadata_shape_unit(monkeypatch):
    """llm_tools returns the get_metadata() dicts for every registered tool."""
    plugin_app = _load_plugin_app(monkeypatch)

    class _FakeTool:
        def get_metadata(self):
            return {"name": "dummy", "description": "test"}

    monkeypatch.setattr(plugin_app.GlobalPluginManager, "get_llm_tools", staticmethod(lambda: [_FakeTool()]))
    payload = plugin_app.llm_tools()
    assert payload["code"] == 0
    assert isinstance(payload["data"], list)
    first = payload["data"][0]
    assert first["name"] == "dummy"
    assert first["description"] == "test"
|
||||
|
||||
@@ -0,0 +1,242 @@
|
||||
#
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import importlib.util
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from types import ModuleType, SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Unauthorized as WerkzeugUnauthorized
|
||||
|
||||
|
||||
class _DummyAPIToken:
|
||||
@staticmethod
|
||||
def query(**_kwargs):
|
||||
return []
|
||||
|
||||
|
||||
class _DummyUserService:
|
||||
@staticmethod
|
||||
def query(**_kwargs):
|
||||
return []
|
||||
|
||||
|
||||
def _run(coro):
|
||||
return asyncio.run(coro)
|
||||
|
||||
|
||||
def _load_apps_module(monkeypatch):
    """Import api/apps/__init__.py in isolation, stubbing its heavy deps.

    Returns ``(quart_app, module)``. Every stub is registered through
    *monkeypatch* so ``sys.modules`` is restored after the test.
    """
    repo_root = Path(__file__).resolve().parents[4]

    def _install(dotted_name):
        # Create an empty module object and make it importable under dotted_name.
        stub = ModuleType(dotted_name)
        monkeypatch.setitem(sys.modules, dotted_name, stub)
        return stub

    common_pkg = _install("common")
    common_pkg.__path__ = [str(repo_root / "common")]

    settings_stub = _install("common.settings")
    settings_stub.SECRET_KEY = "test-secret-key"
    settings_stub.init_settings = lambda: None
    settings_stub.decrypt_database_config = lambda name=None: {}
    common_pkg.settings = settings_stub

    db_models_stub = _install("api.db.db_models")
    db_models_stub.APIToken = _DummyAPIToken
    db_models_stub.close_connection = lambda: None

    services_stub = _install("api.db.services")
    services_stub.UserService = _DummyUserService

    commands_stub = _install("api.utils.commands")
    commands_stub.register_commands = lambda _app: None

    api_utils_stub = _install("api.utils.api_utils")

    def _json_result(code=0, message="success", data=None):
        return {"code": code, "message": message, "data": data}

    def _error_response(error):
        return {"code": 100, "message": repr(error)}

    # Attribute names must match what the apps package imports.
    api_utils_stub.get_json_result = _json_result
    api_utils_stub.server_error_response = _error_response

    target_name = "test_apps_init_unit_module"
    target_path = repo_root / "api" / "apps" / "__init__.py"
    spec = importlib.util.spec_from_file_location(target_name, target_path)
    apps_module = importlib.util.module_from_spec(spec)
    monkeypatch.setitem(sys.modules, target_name, apps_module)

    # Keep the package from auto-discovering sibling blueprint files on import.
    monkeypatch.setattr(Path, "glob", lambda self, _pattern: [])
    spec.loader.exec_module(apps_module)
    return apps_module.app, apps_module
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_module_init_and_unauthorized_message_variants(monkeypatch):
    """_unauthorized_message masks everything except genuine 401 reprs."""
    _quart_app, apps_module = _load_apps_module(monkeypatch)

    assert apps_module.client_urls_prefix == []

    default_msg = apps_module.UNAUTHORIZED_MESSAGE

    class _RaisingRepr:
        def __repr__(self):
            raise RuntimeError("repr explode")

    class _DefaultRepr:
        def __repr__(self):
            return apps_module.UNAUTHORIZED_MESSAGE

    class _UpstreamRepr:
        def __repr__(self):
            return "Unauthorized 401 from upstream"

    class _ForbiddenRepr:
        def __repr__(self):
            return "Forbidden 403"

    cases = [
        (None, default_msg),  # no error object at all
        (_RaisingRepr(), default_msg),  # repr() itself blows up
        (_DefaultRepr(), default_msg),  # repr equals the default message
        (_UpstreamRepr(), "Unauthorized 401 from upstream"),  # genuine 401 passes through
        (_ForbiddenRepr(), default_msg),  # non-401 errors are masked
    ]
    for error, expected in cases:
        assert apps_module._unauthorized_message(error) == expected
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_load_user_token_edge_cases(monkeypatch):
    """_load_user returns None for empty, short, or stale decoded tokens."""
    quart_app, apps_module = _load_apps_module(monkeypatch)

    stale_user = SimpleNamespace(email="empty@example.com", access_token="")

    async def _scenarios():
        auth_headers = {"Authorization": "token"}

        # Decoded token is empty.
        async with quart_app.test_request_context("/", headers=auth_headers):
            monkeypatch.setattr(apps_module.Serializer, "loads", lambda _self, _auth: "")
            assert apps_module._load_user() is None

        # Decoded token is too short to be valid.
        async with quart_app.test_request_context("/", headers=auth_headers):
            monkeypatch.setattr(apps_module.Serializer, "loads", lambda _self, _auth: "short-token")
            assert apps_module._load_user() is None

        # Token decodes, but the matched user has an empty access token.
        async with quart_app.test_request_context("/", headers=auth_headers):
            monkeypatch.setattr(apps_module.Serializer, "loads", lambda _self, _auth: "a" * 32)
            monkeypatch.setattr(apps_module.UserService, "query", lambda **_kwargs: [stale_user])
            assert apps_module._load_user() is None

    _run(_scenarios())
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_load_user_api_token_fallback_and_fallback_exception(monkeypatch, caplog):
    """When token decoding fails, _load_user tries APIToken and survives errors there."""
    quart_app, apps_module = _load_apps_module(monkeypatch)

    def _broken_loads(_self, _auth):
        raise RuntimeError("decode failed")

    # Force every Serializer.loads call to fail so the APIToken path is taken.
    monkeypatch.setattr(apps_module.Serializer, "loads", _broken_loads)

    tenant_user = SimpleNamespace(email="fallback@example.com", access_token="")
    bearer_headers = {"Authorization": "Bearer api-token"}

    async def _scenarios():
        # Fallback finds a tenant, but its user has an empty access token.
        monkeypatch.setattr(apps_module.APIToken, "query", lambda **_kwargs: [SimpleNamespace(tenant_id="tenant-1")])
        monkeypatch.setattr(apps_module.UserService, "query", lambda **_kwargs: [tenant_user])
        async with quart_app.test_request_context("/", headers=bearer_headers):
            assert apps_module._load_user() is None

        # The APIToken lookup itself raising must be logged, not propagated.
        def _broken_query(**_kwargs):
            raise RuntimeError("api token fallback failed")

        monkeypatch.setattr(apps_module.APIToken, "query", _broken_query)
        async with quart_app.test_request_context("/", headers=bearer_headers):
            with caplog.at_level(logging.WARNING):
                assert apps_module._load_user() is None

    _run(_scenarios())
    assert "api token fallback failed" in caplog.text
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_login_required_timing_and_login_user_inactive(monkeypatch, caplog):
    """RAGFLOW_API_TIMING=1 logs timing info, and inactive users cannot log in."""
    quart_app, apps_module = _load_apps_module(monkeypatch)

    monkeypatch.setenv("RAGFLOW_API_TIMING", "1")
    monkeypatch.setattr(apps_module, "current_user", SimpleNamespace(id="tenant-1"))

    @apps_module.login_required
    async def _handler():
        return {"ok": True}

    async def _scenario():
        async with quart_app.test_request_context("/timed"):
            with caplog.at_level(logging.INFO):
                result = await _handler()
                assert result == {"ok": True}

        # login_user refuses a user whose is_active flag is False.
        assert apps_module.login_user(SimpleNamespace(id="user-1", is_active=False)) is False

    _run(_scenario())
    assert "api_timing login_required" in caplog.text
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_logout_user_not_found_and_unauthorized_handlers(monkeypatch):
    """Cover logout_user session cleanup plus the 404 and 401 error handlers."""
    quart_app, apps_module = _load_apps_module(monkeypatch)

    async def _check_logout():
        async with quart_app.test_request_context("/logout", headers={"Cookie": "remember_token=abc"}):
            from quart import session

            session["_user_id"] = "user-1"
            session["_fresh"] = True
            session["_id"] = "session-id"
            session["_remember_seconds"] = 5

            assert apps_module.logout_user() is True
            # Identity keys are dropped; the remember-cookie is flagged for clearing.
            for key in ("_user_id", "_fresh", "_id", "_remember_seconds"):
                assert key not in session
            assert session.get("_remember") == "clear"

    async def _check_not_found():
        async with quart_app.test_request_context("/missing/path"):
            response, status = await apps_module.not_found(RuntimeError("missing"))
            assert status == apps_module.RetCode.NOT_FOUND
            body = await response.get_json()
            assert body["code"] == apps_module.RetCode.NOT_FOUND
            assert body["error"] == "Not Found"
            assert "Not Found:" in body["message"]

    async def _check_unauthorized():
        async with quart_app.test_request_context("/protected"):
            @apps_module.login_required
            async def _protected():
                return {"ok": True}

            monkeypatch.setattr(apps_module, "current_user", None)
            with pytest.raises(apps_module.QuartAuthUnauthorized) as exc_info:
                await _protected()

            body, status = await apps_module.unauthorized_quart_auth(exc_info.value)
            assert status == apps_module.RetCode.UNAUTHORIZED
            assert body["code"] == apps_module.RetCode.UNAUTHORIZED

            body, status = await apps_module.unauthorized_werkzeug(WerkzeugUnauthorized("Unauthorized 401"))
            assert status == apps_module.RetCode.UNAUTHORIZED
            assert body["code"] == apps_module.RetCode.UNAUTHORIZED

    async def _case():
        await _check_logout()
        await _check_not_found()
        await _check_unauthorized()

    _run(_case())
|
||||
Reference in New Issue
Block a user