From 157208ab1e01ddc530a79982bd2ce2ddb5eb720e Mon Sep 17 00:00:00 2001 From: mahammadasim <135003320+mahammadasim@users.noreply.github.com> Date: Thu, 12 Mar 2026 13:04:20 +0530 Subject: [PATCH] =?UTF-8?q?test:=20added=20test=20for=20services=20of=20op?= =?UTF-8?q?s,=20summary,=20vector,=20website=20and=20ji=E2=80=A6=20(#32893?= =?UTF-8?q?)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: akashseth-ifp --- .../auth/test_jina_auth_standalone_module.py | 157 ++ .../unit_tests/services/test_ops_service.py | 381 +++++ .../services/test_summary_index_service.py | 1329 +++++++++++++++++ .../services/test_vector_service.py | 704 +++++++++ .../services/test_website_service.py | 718 +++++++++ 5 files changed, 3289 insertions(+) create mode 100644 api/tests/unit_tests/services/auth/test_jina_auth_standalone_module.py create mode 100644 api/tests/unit_tests/services/test_ops_service.py create mode 100644 api/tests/unit_tests/services/test_summary_index_service.py create mode 100644 api/tests/unit_tests/services/test_vector_service.py create mode 100644 api/tests/unit_tests/services/test_website_service.py diff --git a/api/tests/unit_tests/services/auth/test_jina_auth_standalone_module.py b/api/tests/unit_tests/services/auth/test_jina_auth_standalone_module.py new file mode 100644 index 0000000000..c2fcd71875 --- /dev/null +++ b/api/tests/unit_tests/services/auth/test_jina_auth_standalone_module.py @@ -0,0 +1,157 @@ +from __future__ import annotations + +import importlib.util +import sys +from pathlib import Path +from types import ModuleType +from unittest.mock import MagicMock + +import httpx +import pytest + + +@pytest.fixture(scope="module") +def jina_module() -> ModuleType: + """ + Load `api/services/auth/jina.py` as a standalone module. + + This repo contains both `services/auth/jina.py` and a package at + `services/auth/jina/`, so importing `services.auth.jina` can be ambiguous. + """ + + module_path = Path(__file__).resolve().parents[4] / "services" / "auth" / "jina.py" + # Use a stable module name so pytest-cov can target it with `--cov=services.auth.jina_file`. + spec = importlib.util.spec_from_file_location("services.auth.jina_file", module_path) + assert spec is not None + assert spec.loader is not None + module = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = module + spec.loader.exec_module(module) + return module + + +def _credentials(api_key: str | None = "test_api_key_123", auth_type: str = "bearer") -> dict: + config: dict = {} if api_key is None else {"api_key": api_key} + return {"auth_type": auth_type, "config": config} + + +def test_init_valid_bearer_credentials(jina_module: ModuleType) -> None: + auth = jina_module.JinaAuth(_credentials()) + assert auth.api_key == "test_api_key_123" + assert auth.credentials["auth_type"] == "bearer" + + +def test_init_rejects_invalid_auth_type(jina_module: ModuleType) -> None: + with pytest.raises(ValueError, match="Invalid auth type.*Bearer"): + jina_module.JinaAuth(_credentials(auth_type="basic")) + + +@pytest.mark.parametrize("credentials", [{"auth_type": "bearer", "config": {}}, {"auth_type": "bearer"}]) +def test_init_requires_api_key(jina_module: ModuleType, credentials: dict) -> None: + with pytest.raises(ValueError, match="No API key provided"): + jina_module.JinaAuth(credentials) + + +def test_prepare_headers_includes_bearer_api_key(jina_module: ModuleType) -> None: + auth = jina_module.JinaAuth(_credentials(api_key="k")) + assert auth._prepare_headers() == {"Content-Type": "application/json", "Authorization": "Bearer k"} + + +def test_post_request_calls_httpx(jina_module: ModuleType, monkeypatch: pytest.MonkeyPatch) -> None: + auth = jina_module.JinaAuth(_credentials(api_key="k")) + post_mock = MagicMock(name="httpx.post") + monkeypatch.setattr(jina_module.httpx, "post", post_mock) + + auth._post_request("https://r.jina.ai", {"url": "https://example.com"}, {"h": "v"}) + post_mock.assert_called_once_with("https://r.jina.ai", headers={"h": "v"}, json={"url": "https://example.com"}) + + +def test_validate_credentials_success(jina_module: ModuleType, monkeypatch: pytest.MonkeyPatch) -> None: + auth = jina_module.JinaAuth(_credentials(api_key="k")) + + response = MagicMock() + response.status_code = 200 + post_mock = MagicMock(return_value=response) + monkeypatch.setattr(jina_module.httpx, "post", post_mock) + + assert auth.validate_credentials() is True + post_mock.assert_called_once_with( + "https://r.jina.ai", + headers={"Content-Type": "application/json", "Authorization": "Bearer k"}, + json={"url": "https://example.com"}, + ) + + +def test_validate_credentials_non_200_raises_via_handle_error( + jina_module: ModuleType, monkeypatch: pytest.MonkeyPatch +) -> None: + auth = jina_module.JinaAuth(_credentials(api_key="k")) + + response = MagicMock() + response.status_code = 402 + response.json.return_value = {"error": "Payment required"} + monkeypatch.setattr(jina_module.httpx, "post", MagicMock(return_value=response)) + + with pytest.raises(Exception, match="Status code: 402.*Payment required"): + auth.validate_credentials() + + +@pytest.mark.parametrize("status_code", [402, 409, 500]) +def test_handle_error_statuses_use_response_json(jina_module: ModuleType, status_code: int) -> None: + auth = jina_module.JinaAuth(_credentials(api_key="k")) + response = MagicMock() + response.status_code = status_code + response.json.return_value = {"error": "boom"} + + with pytest.raises(Exception, match=f"Status code: {status_code}.*boom"): + auth._handle_error(response) + + +def test_handle_error_statuses_default_unknown_error(jina_module: ModuleType) -> None: + auth = jina_module.JinaAuth(_credentials(api_key="k")) + response = MagicMock() + response.status_code = 402 + response.json.return_value = {} + + with pytest.raises(Exception, match="Unknown error occurred"): + auth._handle_error(response) + + +def test_handle_error_with_text_json_body(jina_module: ModuleType) -> None: + auth = jina_module.JinaAuth(_credentials(api_key="k")) + response = MagicMock() + response.status_code = 403 + response.text = '{"error": "Forbidden"}' + + with pytest.raises(Exception, match="Status code: 403.*Forbidden"): + auth._handle_error(response) + + +def test_handle_error_with_text_json_body_missing_error(jina_module: ModuleType) -> None: + auth = jina_module.JinaAuth(_credentials(api_key="k")) + response = MagicMock() + response.status_code = 403 + response.text = "{}" + + with pytest.raises(Exception, match="Unknown error occurred"): + auth._handle_error(response) + + +def test_handle_error_without_text_raises_unexpected(jina_module: ModuleType) -> None: + auth = jina_module.JinaAuth(_credentials(api_key="k")) + response = MagicMock() + response.status_code = 404 + response.text = "" + + with pytest.raises(Exception, match="Unexpected error occurred.*404"): + auth._handle_error(response) + + +def test_validate_credentials_propagates_network_errors( + jina_module: ModuleType, monkeypatch: pytest.MonkeyPatch +) -> None: + auth = jina_module.JinaAuth(_credentials(api_key="k")) + monkeypatch.setattr(jina_module.httpx, "post", MagicMock(side_effect=httpx.ConnectError("boom"))) + + with pytest.raises(httpx.ConnectError, match="boom"): + auth.validate_credentials() diff --git a/api/tests/unit_tests/services/test_ops_service.py b/api/tests/unit_tests/services/test_ops_service.py new file mode 100644 index 0000000000..ab7b473790 --- /dev/null +++ b/api/tests/unit_tests/services/test_ops_service.py @@ -0,0 +1,381 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from core.ops.entities.config_entity import TracingProviderEnum +from models.model import App, TraceAppConfig +from services.ops_service import OpsService + + +class TestOpsService: + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_get_tracing_app_config_no_config(self, mock_ops_trace_manager, mock_db): + # Arrange + mock_db.session.query.return_value.where.return_value.first.return_value = None + + # Act + result = OpsService.get_tracing_app_config("app_id", "arize") + + # Assert + assert result is None + mock_db.session.query.assert_called_with(TraceAppConfig) + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_get_tracing_app_config_no_app(self, mock_ops_trace_manager, mock_db): + # Arrange + trace_config = MagicMock(spec=TraceAppConfig) + mock_db.session.query.return_value.where.return_value.first.side_effect = [trace_config, None] + + # Act + result = OpsService.get_tracing_app_config("app_id", "arize") + + # Assert + assert result is None + assert mock_db.session.query.call_count == 2 + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_get_tracing_app_config_none_config(self, mock_ops_trace_manager, mock_db): + # Arrange + trace_config = MagicMock(spec=TraceAppConfig) + trace_config.tracing_config = None + app = MagicMock(spec=App) + app.tenant_id = "tenant_id" + mock_db.session.query.return_value.where.return_value.first.side_effect = [trace_config, app] + + # Act & Assert + with pytest.raises(ValueError, match="Tracing config cannot be None."): + OpsService.get_tracing_app_config("app_id", "arize") + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + @pytest.mark.parametrize( + ("provider", "default_url"), + [ + ("arize", "https://app.arize.com/"), + ("phoenix", "https://app.phoenix.arize.com/projects/"), + ("langsmith", "https://smith.langchain.com/"), + ("opik", "https://www.comet.com/opik/"), + ("weave", "https://wandb.ai/"), + ("aliyun", "https://arms.console.aliyun.com/"), + ("tencent", "https://console.cloud.tencent.com/apm"), + ("mlflow", "http://localhost:5000/"), + ("databricks", "https://www.databricks.com/"), + ], + ) + def test_get_tracing_app_config_providers_exception(self, mock_ops_trace_manager, mock_db, provider, default_url): + # Arrange + trace_config = MagicMock(spec=TraceAppConfig) + trace_config.tracing_config = {"some": "config"} + trace_config.to_dict.return_value = {"tracing_config": {"project_url": default_url}} + app = MagicMock(spec=App) + app.tenant_id = "tenant_id" + mock_db.session.query.return_value.where.return_value.first.side_effect = [trace_config, app] + + mock_ops_trace_manager.decrypt_tracing_config.return_value = {} + mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {} + mock_ops_trace_manager.get_trace_config_project_url.side_effect = Exception("error") + mock_ops_trace_manager.get_trace_config_project_key.side_effect = Exception("error") + + # Act + result = OpsService.get_tracing_app_config("app_id", provider) + + # Assert + assert result["tracing_config"]["project_url"] == default_url + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + @pytest.mark.parametrize( + "provider", ["arize", "phoenix", "langsmith", "opik", "weave", "aliyun", "tencent", "mlflow", "databricks"] + ) + def test_get_tracing_app_config_providers_success(self, mock_ops_trace_manager, mock_db, provider): + # Arrange + trace_config = MagicMock(spec=TraceAppConfig) + trace_config.tracing_config = {"some": "config"} + trace_config.to_dict.return_value = {"tracing_config": {"project_url": "success_url"}} + app = MagicMock(spec=App) + app.tenant_id = "tenant_id" + mock_db.session.query.return_value.where.return_value.first.side_effect = [trace_config, app] + + mock_ops_trace_manager.decrypt_tracing_config.return_value = {} + mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {} + mock_ops_trace_manager.get_trace_config_project_url.return_value = "success_url" + + # Act + result = OpsService.get_tracing_app_config("app_id", provider) + + # Assert + assert result["tracing_config"]["project_url"] == "success_url" + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_get_tracing_app_config_langfuse_success(self, mock_ops_trace_manager, mock_db): + # Arrange + trace_config = MagicMock(spec=TraceAppConfig) + trace_config.tracing_config = {"some": "config"} + trace_config.to_dict.return_value = {"tracing_config": {"project_url": "https://api.langfuse.com/project/key"}} + app = MagicMock(spec=App) + app.tenant_id = "tenant_id" + mock_db.session.query.return_value.where.return_value.first.side_effect = [trace_config, app] + + mock_ops_trace_manager.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"} + mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"} + mock_ops_trace_manager.get_trace_config_project_key.return_value = "key" + + # Act + result = OpsService.get_tracing_app_config("app_id", "langfuse") + + # Assert + assert result["tracing_config"]["project_url"] == "https://api.langfuse.com/project/key" + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_get_tracing_app_config_langfuse_exception(self, mock_ops_trace_manager, mock_db): + # Arrange + trace_config = MagicMock(spec=TraceAppConfig) + trace_config.tracing_config = {"some": "config"} + trace_config.to_dict.return_value = {"tracing_config": {"project_url": "https://api.langfuse.com/"}} + app = MagicMock(spec=App) + app.tenant_id = "tenant_id" + mock_db.session.query.return_value.where.return_value.first.side_effect = [trace_config, app] + + mock_ops_trace_manager.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"} + mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"} + mock_ops_trace_manager.get_trace_config_project_key.side_effect = Exception("error") + + # Act + result = OpsService.get_tracing_app_config("app_id", "langfuse") + + # Assert + assert result["tracing_config"]["project_url"] == "https://api.langfuse.com/" + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_create_tracing_app_config_invalid_provider(self, mock_ops_trace_manager, mock_db): + # Act + result = OpsService.create_tracing_app_config("app_id", "invalid_provider", {}) + + # Assert + assert result == {"error": "Invalid tracing provider: invalid_provider"} + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_create_tracing_app_config_invalid_credentials(self, mock_ops_trace_manager, mock_db): + # Arrange + provider = TracingProviderEnum.LANGFUSE + mock_ops_trace_manager.check_trace_config_is_effective.return_value = False + + # Act + result = OpsService.create_tracing_app_config("app_id", provider, {"public_key": "p", "secret_key": "s"}) + + # Assert + assert result == {"error": "Invalid Credentials"} + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + @pytest.mark.parametrize( + ("provider", "config"), + [ + (TracingProviderEnum.ARIZE, {}), + (TracingProviderEnum.LANGFUSE, {"public_key": "p", "secret_key": "s"}), + (TracingProviderEnum.LANGSMITH, {"api_key": "k", "project": "p"}), + (TracingProviderEnum.ALIYUN, {"license_key": "k", "endpoint": "https://aliyun.com"}), + ], + ) + def test_create_tracing_app_config_project_url_exception(self, mock_ops_trace_manager, mock_db, provider, config): + # Arrange + mock_ops_trace_manager.check_trace_config_is_effective.return_value = True + mock_ops_trace_manager.get_trace_config_project_url.side_effect = Exception("error") + mock_ops_trace_manager.get_trace_config_project_key.side_effect = Exception("error") + mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock(spec=TraceAppConfig) + + # Act + result = OpsService.create_tracing_app_config("app_id", provider, config) + + # Assert + assert result is None + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_create_tracing_app_config_langfuse_success(self, mock_ops_trace_manager, mock_db): + # Arrange + provider = TracingProviderEnum.LANGFUSE + mock_ops_trace_manager.check_trace_config_is_effective.return_value = True + mock_ops_trace_manager.get_trace_config_project_key.return_value = "key" + app = MagicMock(spec=App) + app.tenant_id = "tenant_id" + mock_db.session.query.return_value.where.return_value.first.side_effect = [None, app] + mock_ops_trace_manager.encrypt_tracing_config.return_value = {} + + # Act + result = OpsService.create_tracing_app_config( + "app_id", provider, {"public_key": "p", "secret_key": "s", "host": "https://api.langfuse.com"} + ) + + # Assert + assert result == {"result": "success"} + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_create_tracing_app_config_already_exists(self, mock_ops_trace_manager, mock_db): + # Arrange + provider = TracingProviderEnum.ARIZE + mock_ops_trace_manager.check_trace_config_is_effective.return_value = True + mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock(spec=TraceAppConfig) + + # Act + result = OpsService.create_tracing_app_config("app_id", provider, {}) + + # Assert + assert result is None + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_create_tracing_app_config_no_app(self, mock_ops_trace_manager, mock_db): + # Arrange + provider = TracingProviderEnum.ARIZE + mock_ops_trace_manager.check_trace_config_is_effective.return_value = True + mock_db.session.query.return_value.where.return_value.first.side_effect = [None, None] + + # Act + result = OpsService.create_tracing_app_config("app_id", provider, {}) + + # Assert + assert result is None + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_create_tracing_app_config_with_empty_other_keys(self, mock_ops_trace_manager, mock_db): + # Arrange + provider = TracingProviderEnum.ARIZE + mock_ops_trace_manager.check_trace_config_is_effective.return_value = True + app = MagicMock(spec=App) + app.tenant_id = "tenant_id" + mock_db.session.query.return_value.where.return_value.first.side_effect = [None, app] + mock_ops_trace_manager.encrypt_tracing_config.return_value = {} + + # Act + # 'project' is in other_keys for Arize + # provide an empty string for the project in the tracing_config + # create_tracing_app_config will replace it with the default from the model + result = OpsService.create_tracing_app_config("app_id", provider, {"project": ""}) + + # Assert + assert result == {"result": "success"} + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_create_tracing_app_config_success(self, mock_ops_trace_manager, mock_db): + # Arrange + provider = TracingProviderEnum.ARIZE + mock_ops_trace_manager.check_trace_config_is_effective.return_value = True + mock_ops_trace_manager.get_trace_config_project_url.return_value = "http://project_url" + app = MagicMock(spec=App) + app.tenant_id = "tenant_id" + mock_db.session.query.return_value.where.return_value.first.side_effect = [None, app] + mock_ops_trace_manager.encrypt_tracing_config.return_value = {"encrypted": "config"} + + # Act + result = OpsService.create_tracing_app_config("app_id", provider, {}) + + # Assert + assert result == {"result": "success"} + mock_db.session.add.assert_called() + mock_db.session.commit.assert_called() + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_update_tracing_app_config_invalid_provider(self, mock_ops_trace_manager, mock_db): + # Act & Assert + with pytest.raises(ValueError, match="Invalid tracing provider: invalid_provider"): + OpsService.update_tracing_app_config("app_id", "invalid_provider", {}) + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_update_tracing_app_config_no_config(self, mock_ops_trace_manager, mock_db): + # Arrange + provider = TracingProviderEnum.ARIZE + mock_db.session.query.return_value.where.return_value.first.return_value = None + + # Act + result = OpsService.update_tracing_app_config("app_id", provider, {}) + + # Assert + assert result is None + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_update_tracing_app_config_no_app(self, mock_ops_trace_manager, mock_db): + # Arrange + provider = TracingProviderEnum.ARIZE + current_config = MagicMock(spec=TraceAppConfig) + mock_db.session.query.return_value.where.return_value.first.side_effect = [current_config, None] + + # Act + result = OpsService.update_tracing_app_config("app_id", provider, {}) + + # Assert + assert result is None + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_update_tracing_app_config_invalid_credentials(self, mock_ops_trace_manager, mock_db): + # Arrange + provider = TracingProviderEnum.ARIZE + current_config = MagicMock(spec=TraceAppConfig) + app = MagicMock(spec=App) + app.tenant_id = "tenant_id" + mock_db.session.query.return_value.where.return_value.first.side_effect = [current_config, app] + mock_ops_trace_manager.decrypt_tracing_config.return_value = {} + mock_ops_trace_manager.check_trace_config_is_effective.return_value = False + + # Act & Assert + with pytest.raises(ValueError, match="Invalid Credentials"): + OpsService.update_tracing_app_config("app_id", provider, {}) + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_update_tracing_app_config_success(self, mock_ops_trace_manager, mock_db): + # Arrange + provider = TracingProviderEnum.ARIZE + current_config = MagicMock(spec=TraceAppConfig) + current_config.to_dict.return_value = {"some": "data"} + app = MagicMock(spec=App) + app.tenant_id = "tenant_id" + mock_db.session.query.return_value.where.return_value.first.side_effect = [current_config, app] + mock_ops_trace_manager.decrypt_tracing_config.return_value = {} + mock_ops_trace_manager.check_trace_config_is_effective.return_value = True + + # Act + result = OpsService.update_tracing_app_config("app_id", provider, {}) + + # Assert + assert result == {"some": "data"} + mock_db.session.commit.assert_called_once() + + @patch("services.ops_service.db") + def test_delete_tracing_app_config_no_config(self, mock_db): + # Arrange + mock_db.session.query.return_value.where.return_value.first.return_value = None + + # Act + result = OpsService.delete_tracing_app_config("app_id", "arize") + + # Assert + assert result is None + + @patch("services.ops_service.db") + def test_delete_tracing_app_config_success(self, mock_db): + # Arrange + trace_config = MagicMock(spec=TraceAppConfig) + mock_db.session.query.return_value.where.return_value.first.return_value = trace_config + + # Act + result = OpsService.delete_tracing_app_config("app_id", "arize") + + # Assert + assert result is True + mock_db.session.delete.assert_called_with(trace_config) + mock_db.session.commit.assert_called_once() diff --git a/api/tests/unit_tests/services/test_summary_index_service.py b/api/tests/unit_tests/services/test_summary_index_service.py new file mode 100644 index 0000000000..c7e1fed21f --- /dev/null +++ b/api/tests/unit_tests/services/test_summary_index_service.py @@ -0,0 +1,1329 @@ +"""Unit tests for services.summary_index_service.""" + +from __future__ import annotations + +import sys +from dataclasses import dataclass +from datetime import UTC, datetime +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +import services.summary_index_service as summary_module +from services.summary_index_service import SummaryIndexService + + +@dataclass(frozen=True) +class _SessionContext: + session: MagicMock + + def __enter__(self) -> MagicMock: + return self.session + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + +def _dataset(*, indexing_technique: str = "high_quality") -> MagicMock: + dataset = MagicMock(name="dataset") + dataset.id = "dataset-1" + dataset.tenant_id = "tenant-1" + dataset.indexing_technique = indexing_technique + dataset.embedding_model_provider = "openai" + dataset.embedding_model = "text-embedding" + return dataset + + +def _segment(*, has_document: bool = True) -> MagicMock: + segment = MagicMock(name="segment") + segment.id = "seg-1" + segment.document_id = "doc-1" + segment.dataset_id = "dataset-1" + segment.content = "hello world" + segment.enabled = True + segment.status = "completed" + segment.position = 1 + if has_document: + doc = MagicMock(name="document") + doc.doc_language = "en" + doc.doc_form = "text_model" + segment.document = doc + else: + segment.document = None + return segment + + +def _summary_record(*, summary_content: str = "summary", node_id: str | None = None) -> MagicMock: + record = MagicMock(spec=summary_module.DocumentSegmentSummary, name="summary_record") + record.id = "sum-1" + record.dataset_id = "dataset-1" + record.document_id = "doc-1" + record.chunk_id = "seg-1" + record.summary_content = summary_content + record.summary_index_node_id = node_id + record.summary_index_node_hash = None + record.tokens = None + record.status = "generating" + record.error = None + record.enabled = True + record.created_at = datetime(2024, 1, 1, tzinfo=UTC) + record.updated_at = datetime(2024, 1, 1, tzinfo=UTC) + record.disabled_at = None + record.disabled_by = None + return record + + +def test_generate_summary_for_segment_passes_document_language(monkeypatch: pytest.MonkeyPatch) -> None: + usage = MagicMock() + usage.total_tokens = 10 + usage.prompt_tokens = 3 + usage.completion_tokens = 7 + + paragraph_module = SimpleNamespace( + ParagraphIndexProcessor=SimpleNamespace(generate_summary=MagicMock(return_value=("sum", usage))) + ) + monkeypatch.setitem( + sys.modules, + "core.rag.index_processor.processor.paragraph_index_processor", + paragraph_module, + ) + + segment = _segment(has_document=True) + dataset = _dataset() + + content, got_usage = SummaryIndexService.generate_summary_for_segment(segment, dataset, {"a": 1}) + assert content == "sum" + assert got_usage is usage + + paragraph_module.ParagraphIndexProcessor.generate_summary.assert_called_once() + _, kwargs = paragraph_module.ParagraphIndexProcessor.generate_summary.call_args + assert kwargs["document_language"] == "en" + + +def test_generate_summary_for_segment_raises_when_empty(monkeypatch: pytest.MonkeyPatch) -> None: + paragraph_module = SimpleNamespace( + ParagraphIndexProcessor=SimpleNamespace(generate_summary=MagicMock(return_value=("", MagicMock()))) + ) + monkeypatch.setitem( + sys.modules, + "core.rag.index_processor.processor.paragraph_index_processor", + paragraph_module, + ) + + with pytest.raises(ValueError, match="Generated summary is empty"): + SummaryIndexService.generate_summary_for_segment(_segment(), _dataset(), {"a": 1}) + + +def test_create_summary_record_updates_existing_and_reenables(monkeypatch: pytest.MonkeyPatch) -> None: + existing = _summary_record(summary_content="old", node_id="n1") + existing.enabled = False + existing.disabled_at = datetime(2024, 1, 1) + existing.disabled_by = "u" + + session = MagicMock(name="session") + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = existing + session.query.return_value = query + + create_session_mock = MagicMock(return_value=_SessionContext(session)) + monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock)) + + segment = _segment() + dataset = _dataset() + + result = SummaryIndexService.create_summary_record(segment, dataset, "new", status="generating") + assert result is existing + assert existing.summary_content == "new" + assert existing.status == "generating" + assert existing.enabled is True + assert existing.disabled_at is None + assert existing.disabled_by is None + assert existing.error is None + session.add.assert_called_once_with(existing) + session.flush.assert_called_once() + + +def test_create_summary_record_creates_new(monkeypatch: pytest.MonkeyPatch) -> None: + session = MagicMock(name="session") + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = None + session.query.return_value = query + + create_session_mock = MagicMock(return_value=_SessionContext(session)) + monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock)) + + record = SummaryIndexService.create_summary_record(_segment(), _dataset(), "new", status="generating") + assert record.dataset_id == "dataset-1" + assert record.chunk_id == "seg-1" + assert record.summary_content == "new" + assert record.enabled is True + session.add.assert_called_once() + session.flush.assert_called_once() + + +def test_vectorize_summary_skips_non_high_quality(monkeypatch: pytest.MonkeyPatch) -> None: + vector_cls = MagicMock() + monkeypatch.setattr(summary_module, "Vector", vector_cls) + SummaryIndexService.vectorize_summary(_summary_record(), _segment(), _dataset(indexing_technique="economy")) + vector_cls.assert_not_called() + + +def test_vectorize_summary_raises_for_blank_content() -> None: + with pytest.raises(ValueError, match="Summary content is empty"): + SummaryIndexService.vectorize_summary(_summary_record(summary_content=" "), _segment(), _dataset()) + + +def test_vectorize_summary_retries_connection_errors_then_succeeds(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + summary = _summary_record(summary_content="sum", node_id=None) + + monkeypatch.setattr(summary_module.uuid, "uuid4", MagicMock(return_value="uuid-1")) + monkeypatch.setattr(summary_module.helper, "generate_text_hash", MagicMock(return_value="hash-1")) + + embedding_model = MagicMock() + embedding_model.get_text_embedding_num_tokens.return_value = [5] + model_manager = MagicMock() + model_manager.get_model_instance.return_value = embedding_model + monkeypatch.setattr(summary_module, "ModelManager", MagicMock(return_value=model_manager)) + + vector_instance = MagicMock() + vector_instance.add_texts.side_effect = [RuntimeError("connection timeout"), None] + monkeypatch.setattr(summary_module, "Vector", MagicMock(return_value=vector_instance)) + + session = MagicMock(name="provided_session") + merged = _summary_record(summary_content="sum") + session.merge.return_value = merged + monkeypatch.setattr(summary_module.time, "sleep", MagicMock()) + + SummaryIndexService.vectorize_summary(summary, segment, dataset, session=session) + + assert vector_instance.add_texts.call_count == 2 + summary_module.time.sleep.assert_called_once() # type: ignore[attr-defined] + session.flush.assert_called_once() + assert summary.status == "completed" + assert summary.summary_index_node_id == "uuid-1" + assert summary.summary_index_node_hash == "hash-1" + assert summary.tokens == 5 + + +def test_vectorize_summary_without_session_creates_record_when_missing(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + summary = _summary_record(summary_content="sum", node_id="old-node") + + monkeypatch.setattr(summary_module.helper, "generate_text_hash", MagicMock(return_value="hash-1")) + + # Force deletion branch to run and swallow delete failures. + vector_for_delete = MagicMock() + vector_for_delete.delete_by_ids.side_effect = RuntimeError("delete failed") + vector_for_add = MagicMock() + vector_for_add.add_texts.return_value = None + vector_cls = MagicMock(side_effect=[vector_for_delete, vector_for_add]) + monkeypatch.setattr(summary_module, "Vector", vector_cls) + + model_manager = MagicMock() + model_manager.get_model_instance.side_effect = RuntimeError("no model") + monkeypatch.setattr(summary_module, "ModelManager", MagicMock(return_value=model_manager)) + + # New session used after vectorization succeeds (record not found by id nor chunk_id). + session = MagicMock(name="session") + q1 = MagicMock() + q1.filter_by.return_value = q1 + q1.first.side_effect = [None, None] + session.query.return_value = q1 + + create_session_mock = MagicMock(return_value=_SessionContext(session)) + monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock)) + + SummaryIndexService.vectorize_summary(summary, segment, dataset, session=None) + + # One context for success path, no error handler session. + create_session_mock.assert_called() + session.add.assert_called() + session.commit.assert_called_once() + assert summary.status == "completed" + assert summary.summary_index_node_id == "old-node" # reused + + +def test_vectorize_summary_final_failure_updates_error_status(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + summary = _summary_record(summary_content="sum", node_id=None) + + monkeypatch.setattr(summary_module.uuid, "uuid4", MagicMock(return_value="uuid-1")) + monkeypatch.setattr(summary_module.helper, "generate_text_hash", MagicMock(return_value="hash-1")) + monkeypatch.setattr(summary_module.time, "sleep", MagicMock()) + + vector_instance = MagicMock() + vector_instance.add_texts.side_effect = RuntimeError("boom") + monkeypatch.setattr(summary_module, "Vector", MagicMock(return_value=vector_instance)) + + # error_session should find record and commit status update + error_session = MagicMock(name="error_session") + q = MagicMock() + q.filter_by.return_value = q + q.first.return_value = summary + error_session.query.return_value = q + + create_session_mock = MagicMock(return_value=_SessionContext(error_session)) + monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock)) + + with pytest.raises(RuntimeError, match="boom"): + SummaryIndexService.vectorize_summary(summary, segment, dataset, session=None) + + assert summary.status == "error" + assert "Vectorization failed" in (summary.error or "") + error_session.commit.assert_called_once() + + +def test_batch_create_summary_records_no_segments_noop(monkeypatch: pytest.MonkeyPatch) -> None: + create_session_mock = MagicMock() + monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock)) + SummaryIndexService.batch_create_summary_records([], _dataset()) + create_session_mock.assert_not_called() + + +def test_batch_create_summary_records_creates_and_updates(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + s1 = _segment() + s2 = _segment() + s2.id = "seg-2" + s2.document_id = "doc-2" + + existing = _summary_record() + existing.chunk_id = "seg-2" + existing.enabled = False + + session = MagicMock() + query = MagicMock() + query.filter.return_value = query + query.all.return_value = [existing] + session.query.return_value = query + + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + SummaryIndexService.batch_create_summary_records([s1, s2], dataset, status="not_started") + session.commit.assert_called_once() + assert existing.enabled is True + + +def test_update_summary_record_error_updates_when_exists(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + record = _summary_record() + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = record + session.query.return_value = query + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + SummaryIndexService.update_summary_record_error(segment, dataset, "err") + assert record.status == "error" + assert record.error == "err" + session.commit.assert_called_once() + + +def test_generate_and_vectorize_summary_success(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + record = _summary_record(summary_content="") + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = record + session.query.return_value = query + + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + monkeypatch.setattr( + SummaryIndexService, "generate_summary_for_segment", MagicMock(return_value=("sum", MagicMock(total_tokens=0))) + ) + monkeypatch.setattr(SummaryIndexService, "vectorize_summary", MagicMock(return_value=None)) + + out = SummaryIndexService.generate_and_vectorize_summary(segment, dataset, {"enable": True}) + assert out is record + session.refresh.assert_called_once_with(record) + session.commit.assert_called() + + +def test_generate_and_vectorize_summary_vectorize_failure_sets_error(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + record = _summary_record(summary_content="") + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = record + session.query.return_value = query + + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + monkeypatch.setattr( + SummaryIndexService, "generate_summary_for_segment", MagicMock(return_value=("sum", MagicMock(total_tokens=0))) + ) + monkeypatch.setattr(SummaryIndexService, "vectorize_summary", MagicMock(side_effect=RuntimeError("boom"))) + + with pytest.raises(RuntimeError, match="boom"): + SummaryIndexService.generate_and_vectorize_summary(segment, dataset, {"enable": True}) + assert record.status == "error" + # Outer exception handler overwrites the error with the raw exception message. + assert record.error == "boom" + + +def test_vectorize_summary_updates_existing_record_found_by_chunk_id(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + summary = _summary_record(summary_content="sum", node_id=None) + + monkeypatch.setattr(summary_module.uuid, "uuid4", MagicMock(return_value="uuid-1")) + monkeypatch.setattr(summary_module.helper, "generate_text_hash", MagicMock(return_value="hash-1")) + + vector_instance = MagicMock() + vector_instance.add_texts.return_value = None + monkeypatch.setattr(summary_module, "Vector", MagicMock(return_value=vector_instance)) + monkeypatch.setattr( + summary_module, + "ModelManager", + MagicMock(return_value=MagicMock(get_model_instance=MagicMock(return_value=None))), + ) + + existing = _summary_record(summary_content="old", node_id="old-node") + existing.id = "other-id" + session = MagicMock(name="session") + q = MagicMock() + q.filter_by.return_value = q + q.first.side_effect = [None, existing] # miss by id, hit by chunk_id + session.query.return_value = q + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + SummaryIndexService.vectorize_summary(summary, segment, dataset, session=None) + session.commit.assert_called_once() + assert existing.summary_index_node_id == "uuid-1" + + +def test_vectorize_summary_updates_existing_record_found_by_id(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + summary = _summary_record(summary_content="sum", node_id=None) + + monkeypatch.setattr(summary_module.uuid, "uuid4", MagicMock(return_value="uuid-1")) + monkeypatch.setattr(summary_module.helper, "generate_text_hash", MagicMock(return_value="hash-1")) + monkeypatch.setattr( + summary_module, "Vector", MagicMock(return_value=MagicMock(add_texts=MagicMock(return_value=None))) + ) + monkeypatch.setattr( + summary_module, + "ModelManager", + MagicMock(return_value=MagicMock(get_model_instance=MagicMock(return_value=None))), + ) + + existing = _summary_record(summary_content="old", node_id="old-node") + session = MagicMock(name="session") + q = MagicMock() + q.filter_by.return_value = q + q.first.return_value = existing # hit by id + session.query.return_value = q + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + SummaryIndexService.vectorize_summary(summary, segment, dataset, session=None) + session.commit.assert_called_once() + assert existing.summary_index_node_hash == "hash-1" + + +def test_vectorize_summary_session_enter_returns_none_triggers_runtime_error(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + summary = _summary_record(summary_content="sum", node_id=None) + + monkeypatch.setattr(summary_module.uuid, "uuid4", MagicMock(return_value="uuid-1")) + monkeypatch.setattr(summary_module.helper, "generate_text_hash", MagicMock(return_value="hash-1")) + monkeypatch.setattr( + summary_module, "Vector", MagicMock(return_value=MagicMock(add_texts=MagicMock(return_value=None))) + ) + monkeypatch.setattr( + summary_module, + "ModelManager", + MagicMock(return_value=MagicMock(get_model_instance=MagicMock(return_value=None))), + ) + + class _BadContext: + def __enter__(self): + return None + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + error_session = MagicMock() + q = MagicMock() + q.filter_by.return_value = q + q.first.return_value = summary + error_session.query.return_value = q + + create_session_mock = MagicMock(side_effect=[_BadContext(), _SessionContext(error_session)]) + monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock)) + + with pytest.raises(RuntimeError, match="Session should not be None"): + SummaryIndexService.vectorize_summary(summary, segment, dataset, session=None) + + +def test_vectorize_summary_created_record_becomes_none_triggers_guard(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + summary = _summary_record(summary_content="sum", node_id=None) + + monkeypatch.setattr(summary_module.uuid, "uuid4", MagicMock(return_value="uuid-1")) + monkeypatch.setattr(summary_module.helper, "generate_text_hash", MagicMock(return_value="hash-1")) + monkeypatch.setattr( + summary_module, "Vector", MagicMock(return_value=MagicMock(add_texts=MagicMock(return_value=None))) + ) + monkeypatch.setattr( + summary_module, + "ModelManager", + MagicMock(return_value=MagicMock(get_model_instance=MagicMock(return_value=None))), + ) + + session = MagicMock() + q = MagicMock() + q.filter_by.return_value = q + q.first.side_effect = [None, None] # miss by id and chunk_id + session.query.return_value = q + + error_session = MagicMock() + eq = MagicMock() + eq.filter_by.return_value = eq + eq.first.return_value = summary + error_session.query.return_value = eq + + create_session_mock = MagicMock(side_effect=[_SessionContext(session), _SessionContext(error_session)]) + monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock)) + + # Force the created record to be None so the "should not be None" guard triggers. + monkeypatch.setattr(summary_module, "DocumentSegmentSummary", MagicMock(return_value=None)) + + with pytest.raises(RuntimeError, match="summary_record_in_session should not be None"): + SummaryIndexService.vectorize_summary(summary, segment, dataset, session=None) + + +def test_vectorize_summary_error_handler_tries_chunk_id_lookup_and_can_warn_not_found( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dataset = _dataset() + segment = _segment() + summary = _summary_record(summary_content="sum", node_id=None) + + monkeypatch.setattr(summary_module.uuid, "uuid4", MagicMock(return_value="uuid-1")) + monkeypatch.setattr(summary_module.helper, "generate_text_hash", MagicMock(return_value="hash-1")) + monkeypatch.setattr(summary_module.time, "sleep", MagicMock()) + monkeypatch.setattr( + summary_module, + "Vector", + MagicMock(return_value=MagicMock(add_texts=MagicMock(side_effect=RuntimeError("boom")))), + ) + + error_session = MagicMock(name="error_session") + q = MagicMock() + q.filter_by.return_value = q + q.first.side_effect = [None, None] # not found by id, not found by chunk_id + error_session.query.return_value = q + + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(error_session))), + ) + + with pytest.raises(RuntimeError, match="boom"): + SummaryIndexService.vectorize_summary(summary, segment, dataset, session=None) + + # No record -> no commit in error session. + error_session.commit.assert_not_called() + + +def test_update_summary_record_error_warns_when_missing(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = None + session.query.return_value = query + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + logger_mock = MagicMock() + monkeypatch.setattr(summary_module, "logger", logger_mock) + + SummaryIndexService.update_summary_record_error(segment, dataset, "err") + logger_mock.warning.assert_called_once() + + +def test_generate_and_vectorize_summary_creates_missing_record_and_logs_usage(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = None + session.query.return_value = query + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + usage = MagicMock(total_tokens=4, prompt_tokens=1, completion_tokens=3) + monkeypatch.setattr(SummaryIndexService, "generate_summary_for_segment", MagicMock(return_value=("sum", usage))) + monkeypatch.setattr(SummaryIndexService, "vectorize_summary", MagicMock(return_value=None)) + + logger_mock = MagicMock() + monkeypatch.setattr(summary_module, "logger", logger_mock) + + result = SummaryIndexService.generate_and_vectorize_summary(segment, dataset, {"enable": True}) + assert result.status in {"generating", "completed"} + logger_mock.info.assert_called() + + +def test_generate_summaries_for_document_skip_conditions(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset(indexing_technique="economy") + document = MagicMock(spec=summary_module.DatasetDocument) + document.id = "doc-1" + document.doc_form = "text_model" + assert SummaryIndexService.generate_summaries_for_document(dataset, document, {"enable": True}) == [] + + dataset = _dataset() + assert SummaryIndexService.generate_summaries_for_document(dataset, document, {"enable": False}) == [] + + document.doc_form = "qa_model" + assert SummaryIndexService.generate_summaries_for_document(dataset, document, {"enable": True}) == [] + + +def test_generate_summaries_for_document_runs_and_handles_errors(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + document = MagicMock(spec=summary_module.DatasetDocument) + document.id = "doc-1" + document.doc_form = "text_model" + + seg1 = _segment() + seg2 = _segment() + seg2.id = "seg-2" + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.filter.return_value = query + query.all.return_value = [seg1, seg2] + session.query.return_value = query + + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + monkeypatch.setattr(SummaryIndexService, "batch_create_summary_records", MagicMock()) + monkeypatch.setattr( + SummaryIndexService, + "generate_and_vectorize_summary", + MagicMock(side_effect=[MagicMock(), RuntimeError("boom")]), + ) + update_err_mock = MagicMock() + monkeypatch.setattr(SummaryIndexService, "update_summary_record_error", update_err_mock) + + records = SummaryIndexService.generate_summaries_for_document(dataset, document, {"enable": True}) + assert len(records) == 1 + update_err_mock.assert_called_once() + + +def test_generate_summaries_for_document_no_segments_returns_empty(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + document = MagicMock(spec=summary_module.DatasetDocument) + document.id = "doc-1" + document.doc_form = "text_model" + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.filter.return_value = query + query.all.return_value = [] + session.query.return_value = query + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + assert SummaryIndexService.generate_summaries_for_document(dataset, document, {"enable": True}) == [] + + +def test_generate_summaries_for_document_applies_segment_ids_and_only_parent_chunks( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dataset = _dataset() + document = MagicMock(spec=summary_module.DatasetDocument) + document.id = "doc-1" + document.doc_form = "text_model" + seg = _segment() + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.filter.return_value = query + query.all.return_value = [seg] + session.query.return_value = query + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + monkeypatch.setattr(SummaryIndexService, "batch_create_summary_records", MagicMock()) + monkeypatch.setattr(SummaryIndexService, "generate_and_vectorize_summary", MagicMock(return_value=MagicMock())) + + SummaryIndexService.generate_summaries_for_document( + dataset, + document, + {"enable": True}, + segment_ids=[seg.id], + only_parent_chunks=True, + ) + query.filter.assert_called() + + +def test_disable_summaries_for_segments_handles_vector_delete_error(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + summary1 = _summary_record(summary_content="s", node_id="n1") + summary2 = _summary_record(summary_content="s", node_id=None) + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.filter.return_value = query + query.all.return_value = [summary1, summary2] + session.query.return_value = query + + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + monkeypatch.setattr( + summary_module, + "Vector", + MagicMock(return_value=MagicMock(delete_by_ids=MagicMock(side_effect=RuntimeError("boom")))), + ) + monkeypatch.setitem( + sys.modules, "libs.datetime_utils", SimpleNamespace(naive_utc_now=MagicMock(return_value=datetime(2024, 1, 1))) + ) + + SummaryIndexService.disable_summaries_for_segments(dataset, segment_ids=["seg-1"], disabled_by="u") + assert summary1.enabled is False + assert summary1.disabled_by == "u" + session.commit.assert_called_once() + + +def test_disable_summaries_for_segments_no_summaries_noop(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.filter.return_value = query + query.all.return_value = [] + session.query.return_value = query + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + monkeypatch.setitem( + sys.modules, "libs.datetime_utils", SimpleNamespace(naive_utc_now=MagicMock(return_value=datetime(2024, 1, 1))) + ) + SummaryIndexService.disable_summaries_for_segments(dataset) + session.commit.assert_not_called() + + +def test_enable_summaries_for_segments_skips_non_high_quality() -> None: + SummaryIndexService.enable_summaries_for_segments(_dataset(indexing_technique="economy")) + + +def test_enable_summaries_for_segments_revectorizes_and_enables(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + summary = _summary_record(summary_content="sum", node_id="n1") + summary.enabled = False + + segment = _segment() + segment.id = summary.chunk_id + segment.enabled = True + segment.status = "completed" + + session = MagicMock() + summary_query = MagicMock() + summary_query.filter_by.return_value = summary_query + summary_query.filter.return_value = summary_query + summary_query.all.return_value = [summary] + + seg_query = MagicMock() + seg_query.filter_by.return_value = seg_query + seg_query.first.return_value = segment + + def query_side_effect(model: object) -> MagicMock: + if model is summary_module.DocumentSegmentSummary: + return summary_query + return seg_query + + session.query.side_effect = query_side_effect + + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + vec_mock = MagicMock() + monkeypatch.setattr(SummaryIndexService, "vectorize_summary", vec_mock) + + SummaryIndexService.enable_summaries_for_segments(dataset, segment_ids=[summary.chunk_id]) + vec_mock.assert_called_once() + assert summary.enabled is True + session.commit.assert_called_once() + + +def test_enable_summaries_for_segments_no_summaries_noop(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.filter.return_value = query + query.all.return_value = [] + session.query.return_value = query + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + SummaryIndexService.enable_summaries_for_segments(dataset) + session.commit.assert_not_called() + + +def test_enable_summaries_for_segments_skips_segment_or_content_and_handles_vectorize_error( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dataset = _dataset() + summary1 = _summary_record(summary_content="sum", node_id="n1") + summary1.enabled = False + summary2 = _summary_record(summary_content="", node_id="n2") + summary2.enabled = False + summary3 = _summary_record(summary_content="sum3", node_id="n3") + summary3.enabled = False + + bad_segment = _segment() + bad_segment.enabled = False + bad_segment.status = "completed" + + good_segment = _segment() + good_segment.enabled = True + good_segment.status = "completed" + + session = MagicMock() + summary_query = MagicMock() + summary_query.filter_by.return_value = summary_query + summary_query.filter.return_value = summary_query + summary_query.all.return_value = [summary1, summary2, summary3] + + seg_query = MagicMock() + seg_query.filter_by.return_value = seg_query + seg_query.first.side_effect = [bad_segment, good_segment, good_segment] + + def query_side_effect(model: object) -> MagicMock: + if model is summary_module.DocumentSegmentSummary: + return summary_query + return seg_query + + session.query.side_effect = query_side_effect + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + logger_mock = MagicMock() + monkeypatch.setattr(summary_module, "logger", logger_mock) + monkeypatch.setattr(SummaryIndexService, "vectorize_summary", MagicMock(side_effect=RuntimeError("boom"))) + + SummaryIndexService.enable_summaries_for_segments(dataset) + logger_mock.exception.assert_called_once() + session.commit.assert_called_once() + + +def test_delete_summaries_for_segments_deletes_vectors_and_records(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + summary = _summary_record(summary_content="sum", node_id="n1") + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.filter.return_value = query + query.all.return_value = [summary] + session.query.return_value = query + + vector_instance = MagicMock() + monkeypatch.setattr(summary_module, "Vector", MagicMock(return_value=vector_instance)) + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids=[summary.chunk_id]) + vector_instance.delete_by_ids.assert_called_once_with(["n1"]) + session.delete.assert_called_once_with(summary) + session.commit.assert_called_once() + + +def test_delete_summaries_for_segments_no_summaries_noop(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.filter.return_value = query + query.all.return_value = [] + session.query.return_value = query + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + SummaryIndexService.delete_summaries_for_segments(dataset) + session.commit.assert_not_called() + + +def test_update_summary_for_segment_skip_conditions() -> None: + assert ( + SummaryIndexService.update_summary_for_segment(_segment(), _dataset(indexing_technique="economy"), "x") is None + ) + seg = _segment(has_document=True) + seg.document.doc_form = "qa_model" + assert SummaryIndexService.update_summary_for_segment(seg, _dataset(), "x") is None + + +def test_update_summary_for_segment_empty_content_deletes_existing(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + record = _summary_record(summary_content="old", node_id="n1") + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = record + session.query.return_value = query + + vector_instance = MagicMock() + monkeypatch.setattr(summary_module, "Vector", MagicMock(return_value=vector_instance)) + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + assert SummaryIndexService.update_summary_for_segment(segment, dataset, " ") is None + vector_instance.delete_by_ids.assert_called_once_with(["n1"]) + session.delete.assert_called_once_with(record) + session.commit.assert_called_once() + + +def test_update_summary_for_segment_empty_content_delete_vector_warns(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + record = _summary_record(summary_content="old", node_id="n1") + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = record + session.query.return_value = query + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + vector_instance = MagicMock() + vector_instance.delete_by_ids.side_effect = RuntimeError("boom") + monkeypatch.setattr(summary_module, "Vector", MagicMock(return_value=vector_instance)) + logger_mock = MagicMock() + monkeypatch.setattr(summary_module, "logger", logger_mock) + + assert SummaryIndexService.update_summary_for_segment(segment, dataset, "") is None + logger_mock.warning.assert_called() + + +def test_update_summary_for_segment_empty_content_no_record_noop(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = None + session.query.return_value = query + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + assert SummaryIndexService.update_summary_for_segment(segment, dataset, " ") is None + + +def test_update_summary_for_segment_updates_existing_and_vectorizes(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + record = _summary_record(summary_content="old", node_id="n1") + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = record + session.query.return_value = query + + vector_instance = MagicMock() + monkeypatch.setattr(summary_module, "Vector", MagicMock(return_value=vector_instance)) + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + vectorize_mock = MagicMock() + monkeypatch.setattr(SummaryIndexService, "vectorize_summary", vectorize_mock) + + out = SummaryIndexService.update_summary_for_segment(segment, dataset, "new summary") + assert out is record + vectorize_mock.assert_called_once() + session.refresh.assert_called_once_with(record) + session.commit.assert_called() + + +def test_update_summary_for_segment_existing_vector_delete_warns(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + record = _summary_record(summary_content="old", node_id="n1") + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = record + session.query.return_value = query + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + vector_instance = MagicMock() + vector_instance.delete_by_ids.side_effect = RuntimeError("boom") + monkeypatch.setattr(summary_module, "Vector", MagicMock(return_value=vector_instance)) + monkeypatch.setattr(SummaryIndexService, "vectorize_summary", MagicMock(return_value=None)) + logger_mock = MagicMock() + monkeypatch.setattr(summary_module, "logger", logger_mock) + + SummaryIndexService.update_summary_for_segment(segment, dataset, "new") + logger_mock.warning.assert_called() + + +def test_update_summary_for_segment_existing_vectorize_failure_returns_error_record( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dataset = _dataset() + segment = _segment() + record = _summary_record(summary_content="old", node_id="n1") + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = record + session.query.return_value = query + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + monkeypatch.setattr(SummaryIndexService, "vectorize_summary", MagicMock(side_effect=RuntimeError("boom"))) + + out = SummaryIndexService.update_summary_for_segment(segment, dataset, "new") + assert out is record + assert out.status == "error" + assert "Vectorization failed" in (out.error or "") + + +def test_update_summary_for_segment_new_record_success(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = None + session.query.return_value = query + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + created = _summary_record(summary_content="new", node_id=None) + monkeypatch.setattr(SummaryIndexService, "create_summary_record", MagicMock(return_value=created)) + session.merge.return_value = created + monkeypatch.setattr(SummaryIndexService, "vectorize_summary", MagicMock(return_value=None)) + + out = SummaryIndexService.update_summary_for_segment(segment, dataset, "new") + assert out is created + session.refresh.assert_called() + session.commit.assert_called() + + +def test_update_summary_for_segment_outer_exception_sets_error_and_reraises(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + record = _summary_record(summary_content="old", node_id="n1") + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = record + session.query.return_value = query + session.flush.side_effect = RuntimeError("flush boom") + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + with pytest.raises(RuntimeError, match="flush boom"): + SummaryIndexService.update_summary_for_segment(segment, dataset, "new") + assert record.status == "error" + assert record.error == "flush boom" + session.commit.assert_called() + + +def test_get_segment_summary_and_document_summaries(monkeypatch: pytest.MonkeyPatch) -> None: + record = _summary_record(summary_content="sum", node_id="n1") + session = MagicMock() + + q1 = MagicMock() + q1.where.return_value = q1 + q1.first.return_value = record + + q2 = MagicMock() + q2.filter.return_value = q2 + q2.all.return_value = [record] + + def query_side_effect(model: object) -> MagicMock: + if model is summary_module.DocumentSegmentSummary: + # first call used by get_segment_summary, second by get_document_summaries + if not hasattr(query_side_effect, "_called"): + query_side_effect._called = True # type: ignore[attr-defined] + return q1 + return q2 + return MagicMock() + + session.query.side_effect = query_side_effect + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + assert SummaryIndexService.get_segment_summary("seg-1", "dataset-1") is record + assert SummaryIndexService.get_document_summaries("doc-1", "dataset-1", segment_ids=["seg-1"]) == [record] + + +def test_get_segments_summaries_non_empty(monkeypatch: pytest.MonkeyPatch) -> None: + record1 = _summary_record() + record1.chunk_id = "seg-1" + record2 = _summary_record() + record2.chunk_id = "seg-2" + session = MagicMock() + q = MagicMock() + q.where.return_value = q + q.all.return_value = [record1, record2] + session.query.return_value = q + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + out = SummaryIndexService.get_segments_summaries(["seg-1", "seg-2"], "dataset-1") + assert set(out.keys()) == {"seg-1", "seg-2"} + + +def test_get_document_summary_index_status_no_segments_returns_none(monkeypatch: pytest.MonkeyPatch) -> None: + session = MagicMock() + q = MagicMock() + q.where.return_value = q + q.all.return_value = [] + session.query.return_value = q + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + assert SummaryIndexService.get_document_summary_index_status("doc-1", "dataset-1", "tenant-1") is None + + +def test_get_documents_summary_index_status_empty_input(monkeypatch: pytest.MonkeyPatch) -> None: + assert SummaryIndexService.get_documents_summary_index_status([], "dataset-1", "tenant-1") == {} + + +def test_get_documents_summary_index_status_no_pending_sets_none(monkeypatch: pytest.MonkeyPatch) -> None: + session = MagicMock() + q = MagicMock() + q.where.return_value = q + q.all.return_value = [SimpleNamespace(id="seg-1", document_id="doc-1")] + session.query.return_value = q + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + monkeypatch.setattr( + SummaryIndexService, + "get_segments_summaries", + MagicMock(return_value={"seg-1": SimpleNamespace(status="completed")}), + ) + result = SummaryIndexService.get_documents_summary_index_status(["doc-1"], "dataset-1", "tenant-1") + assert result["doc-1"] is None + + +def test_update_summary_for_segment_creates_new_and_vectorize_fails_returns_error_record( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dataset = _dataset() + segment = _segment() + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = None + session.query.return_value = query + + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + created = _summary_record(summary_content="new", node_id=None) + monkeypatch.setattr(SummaryIndexService, "create_summary_record", MagicMock(return_value=created)) + session.merge.return_value = created + + vectorize_mock = MagicMock(side_effect=RuntimeError("boom")) + monkeypatch.setattr(SummaryIndexService, "vectorize_summary", vectorize_mock) + + out = SummaryIndexService.update_summary_for_segment(segment, dataset, "new") + assert out.status == "error" + assert "Vectorization failed" in (out.error or "") + + +def test_get_segments_summaries_empty_list() -> None: + assert SummaryIndexService.get_segments_summaries([], "dataset-1") == {} + + +def test_get_document_summary_index_status_and_documents_status(monkeypatch: pytest.MonkeyPatch) -> None: + seg_row = SimpleNamespace(id="seg-1", document_id="doc-1") + session = MagicMock() + query = MagicMock() + query.where.return_value = query + query.all.return_value = [SimpleNamespace(id="seg-1")] + session.query.return_value = query + + create_session_mock = MagicMock(return_value=_SessionContext(session)) + monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock)) + + monkeypatch.setattr( + SummaryIndexService, + "get_segments_summaries", + MagicMock(return_value={"seg-1": SimpleNamespace(status="generating")}), + ) + assert SummaryIndexService.get_document_summary_index_status("doc-1", "dataset-1", "tenant-1") == "SUMMARIZING" + + # Multiple docs + query2 = MagicMock() + query2.where.return_value = query2 + query2.all.return_value = [seg_row] + session2 = MagicMock() + session2.query.return_value = query2 + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session2))), + ) + monkeypatch.setattr( + SummaryIndexService, + "get_segments_summaries", + MagicMock(return_value={"seg-1": SimpleNamespace(status="not_started")}), + ) + result = SummaryIndexService.get_documents_summary_index_status(["doc-1", "doc-2"], "dataset-1", "tenant-1") + assert result["doc-1"] == "SUMMARIZING" + assert result["doc-2"] is None + + +def test_get_document_summary_status_detail_counts_and_previews(monkeypatch: pytest.MonkeyPatch) -> None: + segment1 = _segment() + segment1.id = "seg-1" + segment1.position = 1 + segment2 = _segment() + segment2.id = "seg-2" + segment2.position = 2 + + summary1 = _summary_record(summary_content="x" * 150, node_id="n1") + summary1.chunk_id = "seg-1" + summary1.status = "completed" + summary1.error = None + summary1.created_at = datetime(2024, 1, 1, tzinfo=UTC) + summary1.updated_at = datetime(2024, 1, 2, tzinfo=UTC) + + segment_service = SimpleNamespace(get_segments_by_document_and_dataset=MagicMock(return_value=[segment1, segment2])) + monkeypatch.setitem(sys.modules, "services.dataset_service", SimpleNamespace(SegmentService=segment_service)) + + monkeypatch.setattr(SummaryIndexService, "get_document_summaries", MagicMock(return_value=[summary1])) + + detail = SummaryIndexService.get_document_summary_status_detail("doc-1", "dataset-1") + assert detail["total_segments"] == 2 + assert detail["summary_status"]["completed"] == 1 + assert detail["summary_status"]["not_started"] == 1 + assert detail["summaries"][0]["summary_preview"].endswith("...") + assert detail["summaries"][1]["status"] == "not_started" diff --git a/api/tests/unit_tests/services/test_vector_service.py b/api/tests/unit_tests/services/test_vector_service.py new file mode 100644 index 0000000000..7b0103a2a1 --- /dev/null +++ b/api/tests/unit_tests/services/test_vector_service.py @@ -0,0 +1,704 @@ +"""Unit tests for `api/services/vector_service.py`.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any +from unittest.mock import MagicMock + +import pytest + +import services.vector_service as vector_service_module +from services.vector_service import VectorService + + +@dataclass(frozen=True) +class _UploadFileStub: + id: str + name: str + + +@dataclass(frozen=True) +class _ChildDocStub: + page_content: str + metadata: dict[str, Any] + + +@dataclass +class _ParentDocStub: + children: list[_ChildDocStub] + + +def _make_dataset( + *, + indexing_technique: str = "high_quality", + doc_form: str = "text_model", + tenant_id: str = "tenant-1", + dataset_id: str = "dataset-1", + is_multimodal: bool = False, + embedding_model_provider: str | None = "openai", + embedding_model: str = "text-embedding", +) -> MagicMock: + dataset = MagicMock(name="dataset") + dataset.id = dataset_id + dataset.tenant_id = tenant_id + dataset.doc_form = doc_form + dataset.indexing_technique = indexing_technique + dataset.is_multimodal = is_multimodal + dataset.embedding_model_provider = embedding_model_provider + dataset.embedding_model = embedding_model + return dataset + + +def _make_segment( + *, + segment_id: str = "seg-1", + tenant_id: str = "tenant-1", + dataset_id: str = "dataset-1", + document_id: str = "doc-1", + content: str = "hello", + index_node_id: str = "node-1", + index_node_hash: str = "hash-1", + attachments: list[dict[str, str]] | None = None, +) -> MagicMock: + segment = MagicMock(name="segment") + segment.id = segment_id + segment.tenant_id = tenant_id + segment.dataset_id = dataset_id + segment.document_id = document_id + segment.content = content + segment.index_node_id = index_node_id + segment.index_node_hash = index_node_hash + segment.attachments = attachments or [] + return segment + + +def _mock_db_session_for_update_multimodel(*, upload_files: list[_UploadFileStub] | None) -> MagicMock: + session = MagicMock(name="session") + + binding_query = MagicMock(name="binding_query") + binding_query.where.return_value = binding_query + binding_query.delete.return_value = 1 + + upload_query = MagicMock(name="upload_query") + upload_query.where.return_value = upload_query + upload_query.all.return_value = upload_files or [] + + def query_side_effect(model: object) -> MagicMock: + if model is vector_service_module.SegmentAttachmentBinding: + return binding_query + if model is vector_service_module.UploadFile: + return upload_query + return MagicMock(name=f"query({model})") + + session.query.side_effect = query_side_effect + db_mock = MagicMock(name="db") + db_mock.session = session + return db_mock + + +def test_create_segments_vector_regular_indexing_loads_documents_and_keywords(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(is_multimodal=False) + segment = _make_segment() + + index_processor = MagicMock(name="index_processor") + factory_instance = MagicMock(name="IndexProcessorFactory-instance") + factory_instance.init_index_processor.return_value = index_processor + monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance)) + + VectorService.create_segments_vector([["k1"]], [segment], dataset, "text_model") + + index_processor.load.assert_called_once() + args, kwargs = index_processor.load.call_args + assert args[0] == dataset + assert len(args[1]) == 1 + assert args[2] is None + assert kwargs["with_keywords"] is True + assert kwargs["keywords_list"] == [["k1"]] + + +def test_create_segments_vector_regular_indexing_loads_multimodal_documents(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(is_multimodal=True) + segment = _make_segment( + attachments=[ + {"id": "img-1", "name": "a.png"}, + {"id": "img-2", "name": "b.png"}, + ] + ) + + index_processor = MagicMock(name="index_processor") + factory_instance = MagicMock(name="IndexProcessorFactory-instance") + factory_instance.init_index_processor.return_value = index_processor + monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance)) + + VectorService.create_segments_vector([["k1"]], [segment], dataset, "text_model") + + assert index_processor.load.call_count == 2 + first_args, first_kwargs = index_processor.load.call_args_list[0] + assert first_args[0] == dataset + assert len(first_args[1]) == 1 + assert first_kwargs["with_keywords"] is True + + second_args, second_kwargs = index_processor.load.call_args_list[1] + assert second_args[0] == dataset + assert second_args[1] == [] + assert len(second_args[2]) == 2 + assert second_kwargs["with_keywords"] is False + + +def test_create_segments_vector_with_no_segments_does_not_load(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset() + index_processor = MagicMock(name="index_processor") + factory_instance = MagicMock() + factory_instance.init_index_processor.return_value = index_processor + monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance)) + + VectorService.create_segments_vector(None, [], dataset, "text_model") + index_processor.load.assert_not_called() + + +def _mock_parent_child_queries( + *, + dataset_document: object | None, + processing_rule: object | None, +) -> MagicMock: + session = MagicMock(name="session") + + doc_query = MagicMock(name="doc_query") + doc_query.filter_by.return_value = doc_query + doc_query.first.return_value = dataset_document + + rule_query = MagicMock(name="rule_query") + rule_query.where.return_value = rule_query + rule_query.first.return_value = processing_rule + + def query_side_effect(model: object) -> MagicMock: + if model is vector_service_module.DatasetDocument: + return doc_query + if model is vector_service_module.DatasetProcessRule: + return rule_query + return MagicMock(name=f"query({model})") + + session.query.side_effect = query_side_effect + db_mock = MagicMock(name="db") + db_mock.session = session + return db_mock + + +def test_create_segments_vector_parent_child_calls_generate_child_chunks_with_explicit_model( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dataset = _make_dataset( + doc_form=vector_service_module.IndexStructureType.PARENT_CHILD_INDEX, + embedding_model_provider="openai", + indexing_technique="high_quality", + ) + segment = _make_segment() + + dataset_document = MagicMock(name="dataset_document") + dataset_document.id = segment.document_id + dataset_document.dataset_process_rule_id = "rule-1" + dataset_document.doc_language = "en" + dataset_document.created_by = "user-1" + + processing_rule = MagicMock(name="processing_rule") + processing_rule.to_dict.return_value = {"rules": {}} + + monkeypatch.setattr( + vector_service_module, + "db", + _mock_parent_child_queries(dataset_document=dataset_document, processing_rule=processing_rule), + ) + + embedding_model_instance = MagicMock(name="embedding_model_instance") + model_manager_instance = MagicMock(name="model_manager_instance") + model_manager_instance.get_model_instance.return_value = embedding_model_instance + monkeypatch.setattr(vector_service_module, "ModelManager", MagicMock(return_value=model_manager_instance)) + + generate_child_chunks_mock = MagicMock() + monkeypatch.setattr(VectorService, "generate_child_chunks", generate_child_chunks_mock) + + index_processor = MagicMock() + factory_instance = MagicMock() + factory_instance.init_index_processor.return_value = index_processor + monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance)) + + VectorService.create_segments_vector( + None, [segment], dataset, vector_service_module.IndexStructureType.PARENT_CHILD_INDEX + ) + + model_manager_instance.get_model_instance.assert_called_once() + generate_child_chunks_mock.assert_called_once_with( + segment, dataset_document, dataset, embedding_model_instance, processing_rule, False + ) + index_processor.load.assert_not_called() + + +def test_create_segments_vector_parent_child_uses_default_embedding_model_when_provider_missing( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dataset = _make_dataset( + doc_form=vector_service_module.IndexStructureType.PARENT_CHILD_INDEX, + embedding_model_provider=None, + indexing_technique="high_quality", + ) + segment = _make_segment() + + dataset_document = MagicMock() + dataset_document.dataset_process_rule_id = "rule-1" + dataset_document.doc_language = "en" + dataset_document.created_by = "user-1" + + processing_rule = MagicMock() + processing_rule.to_dict.return_value = {"rules": {}} + + monkeypatch.setattr( + vector_service_module, + "db", + _mock_parent_child_queries(dataset_document=dataset_document, processing_rule=processing_rule), + ) + + embedding_model_instance = MagicMock() + model_manager_instance = MagicMock() + model_manager_instance.get_default_model_instance.return_value = embedding_model_instance + monkeypatch.setattr(vector_service_module, "ModelManager", MagicMock(return_value=model_manager_instance)) + + generate_child_chunks_mock = MagicMock() + monkeypatch.setattr(VectorService, "generate_child_chunks", generate_child_chunks_mock) + + index_processor = MagicMock() + factory_instance = MagicMock() + factory_instance.init_index_processor.return_value = index_processor + monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance)) + + VectorService.create_segments_vector( + None, [segment], dataset, vector_service_module.IndexStructureType.PARENT_CHILD_INDEX + ) + + model_manager_instance.get_default_model_instance.assert_called_once() + generate_child_chunks_mock.assert_called_once() + + +def test_create_segments_vector_parent_child_missing_document_logs_warning_and_continues( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dataset = _make_dataset(doc_form=vector_service_module.IndexStructureType.PARENT_CHILD_INDEX) + segment = _make_segment() + + processing_rule = MagicMock() + monkeypatch.setattr( + vector_service_module, + "db", + _mock_parent_child_queries(dataset_document=None, processing_rule=processing_rule), + ) + + logger_mock = MagicMock() + monkeypatch.setattr(vector_service_module, "logger", logger_mock) + + index_processor = MagicMock() + factory_instance = MagicMock() + factory_instance.init_index_processor.return_value = index_processor + monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance)) + + VectorService.create_segments_vector( + None, [segment], dataset, vector_service_module.IndexStructureType.PARENT_CHILD_INDEX + ) + logger_mock.warning.assert_called_once() + index_processor.load.assert_not_called() + + +def test_create_segments_vector_parent_child_missing_processing_rule_raises(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(doc_form=vector_service_module.IndexStructureType.PARENT_CHILD_INDEX) + segment = _make_segment() + + dataset_document = MagicMock() + dataset_document.dataset_process_rule_id = "rule-1" + monkeypatch.setattr( + vector_service_module, + "db", + _mock_parent_child_queries(dataset_document=dataset_document, processing_rule=None), + ) + + with pytest.raises(ValueError, match="No processing rule found"): + VectorService.create_segments_vector( + None, [segment], dataset, vector_service_module.IndexStructureType.PARENT_CHILD_INDEX + ) + + +def test_create_segments_vector_parent_child_non_high_quality_raises(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset( + doc_form=vector_service_module.IndexStructureType.PARENT_CHILD_INDEX, + indexing_technique="economy", + ) + segment = _make_segment() + dataset_document = MagicMock() + dataset_document.dataset_process_rule_id = "rule-1" + processing_rule = MagicMock() + monkeypatch.setattr( + vector_service_module, + "db", + _mock_parent_child_queries(dataset_document=dataset_document, processing_rule=processing_rule), + ) + + with pytest.raises(ValueError, match="not high quality"): + VectorService.create_segments_vector( + None, [segment], dataset, vector_service_module.IndexStructureType.PARENT_CHILD_INDEX + ) + + +def test_update_segment_vector_high_quality_uses_vector(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(indexing_technique="high_quality") + segment = _make_segment() + + vector_instance = MagicMock() + monkeypatch.setattr(vector_service_module, "Vector", MagicMock(return_value=vector_instance)) + + VectorService.update_segment_vector(["k"], segment, dataset) + + vector_instance.delete_by_ids.assert_called_once_with([segment.index_node_id]) + vector_instance.add_texts.assert_called_once() + add_args, add_kwargs = vector_instance.add_texts.call_args + assert len(add_args[0]) == 1 + assert add_kwargs["duplicate_check"] is True + + +def test_update_segment_vector_economy_uses_keyword_with_keywords_list(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(indexing_technique="economy") + segment = _make_segment() + + keyword_instance = MagicMock() + monkeypatch.setattr(vector_service_module, "Keyword", MagicMock(return_value=keyword_instance)) + + VectorService.update_segment_vector(["a", "b"], segment, dataset) + + keyword_instance.delete_by_ids.assert_called_once_with([segment.index_node_id]) + keyword_instance.add_texts.assert_called_once() + args, kwargs = keyword_instance.add_texts.call_args + assert len(args[0]) == 1 + assert kwargs["keywords_list"] == [["a", "b"]] + + +def test_update_segment_vector_economy_uses_keyword_without_keywords_list(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(indexing_technique="economy") + segment = _make_segment() + + keyword_instance = MagicMock() + monkeypatch.setattr(vector_service_module, "Keyword", MagicMock(return_value=keyword_instance)) + + VectorService.update_segment_vector(None, segment, dataset) + keyword_instance.add_texts.assert_called_once() + _, kwargs = keyword_instance.add_texts.call_args + assert "keywords_list" not in kwargs + + +def test_generate_child_chunks_regenerate_cleans_then_saves_children(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(doc_form="text_model", tenant_id="tenant-1", dataset_id="dataset-1") + segment = _make_segment(segment_id="seg-1") + + dataset_document = MagicMock() + dataset_document.id = segment.document_id + dataset_document.doc_language = "en" + dataset_document.created_by = "user-1" + + processing_rule = MagicMock() + processing_rule.to_dict.return_value = {"rules": {}} + + child1 = _ChildDocStub(page_content="c1", metadata={"doc_id": "c1-id", "doc_hash": "c1-h"}) + child2 = _ChildDocStub(page_content="c2", metadata={"doc_id": "c2-id", "doc_hash": "c2-h"}) + transformed = [_ParentDocStub(children=[child1, child2])] + + index_processor = MagicMock() + index_processor.transform.return_value = transformed + factory_instance = MagicMock() + factory_instance.init_index_processor.return_value = index_processor + monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance)) + + child_chunk_ctor = MagicMock(side_effect=lambda **kwargs: kwargs) + monkeypatch.setattr(vector_service_module, "ChildChunk", child_chunk_ctor) + + db_mock = MagicMock() + db_mock.session.add = MagicMock() + db_mock.session.commit = MagicMock() + monkeypatch.setattr(vector_service_module, "db", db_mock) + + VectorService.generate_child_chunks( + segment=segment, + dataset_document=dataset_document, + dataset=dataset, + embedding_model_instance=MagicMock(), + processing_rule=processing_rule, + regenerate=True, + ) + + index_processor.clean.assert_called_once() + _, transform_kwargs = index_processor.transform.call_args + assert transform_kwargs["process_rule"]["rules"]["parent_mode"] == vector_service_module.ParentMode.FULL_DOC + index_processor.load.assert_called_once() + assert db_mock.session.add.call_count == 2 + db_mock.session.commit.assert_called_once() + + +def test_generate_child_chunks_commits_even_when_no_children(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(doc_form="text_model") + segment = _make_segment() + dataset_document = MagicMock() + dataset_document.doc_language = "en" + dataset_document.created_by = "user-1" + + processing_rule = MagicMock() + processing_rule.to_dict.return_value = {"rules": {}} + + index_processor = MagicMock() + index_processor.transform.return_value = [_ParentDocStub(children=[])] + factory_instance = MagicMock() + factory_instance.init_index_processor.return_value = index_processor + monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance)) + + db_mock = MagicMock() + monkeypatch.setattr(vector_service_module, "db", db_mock) + + VectorService.generate_child_chunks( + segment=segment, + dataset_document=dataset_document, + dataset=dataset, + embedding_model_instance=MagicMock(), + processing_rule=processing_rule, + regenerate=False, + ) + + index_processor.load.assert_not_called() + db_mock.session.add.assert_not_called() + db_mock.session.commit.assert_called_once() + + +def test_create_child_chunk_vector_high_quality_adds_texts(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(indexing_technique="high_quality") + child_chunk = MagicMock() + child_chunk.content = "child" + child_chunk.index_node_id = "id" + child_chunk.index_node_hash = "h" + child_chunk.document_id = "doc-1" + child_chunk.dataset_id = "dataset-1" + + vector_instance = MagicMock() + monkeypatch.setattr(vector_service_module, "Vector", MagicMock(return_value=vector_instance)) + + VectorService.create_child_chunk_vector(child_chunk, dataset) + vector_instance.add_texts.assert_called_once() + + +def test_create_child_chunk_vector_economy_noop(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(indexing_technique="economy") + vector_cls = MagicMock() + monkeypatch.setattr(vector_service_module, "Vector", vector_cls) + + child_chunk = MagicMock() + child_chunk.content = "child" + child_chunk.index_node_id = "id" + child_chunk.index_node_hash = "h" + child_chunk.document_id = "doc-1" + child_chunk.dataset_id = "dataset-1" + + VectorService.create_child_chunk_vector(child_chunk, dataset) + vector_cls.assert_not_called() + + +def test_update_child_chunk_vector_high_quality_updates_vector(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(indexing_technique="high_quality") + + new_chunk = MagicMock() + new_chunk.content = "n" + new_chunk.index_node_id = "nid" + new_chunk.index_node_hash = "nh" + new_chunk.document_id = "d" + new_chunk.dataset_id = "ds" + + upd_chunk = MagicMock() + upd_chunk.content = "u" + upd_chunk.index_node_id = "uid" + upd_chunk.index_node_hash = "uh" + upd_chunk.document_id = "d" + upd_chunk.dataset_id = "ds" + + del_chunk = MagicMock() + del_chunk.index_node_id = "did" + + vector_instance = MagicMock() + monkeypatch.setattr(vector_service_module, "Vector", MagicMock(return_value=vector_instance)) + + VectorService.update_child_chunk_vector([new_chunk], [upd_chunk], [del_chunk], dataset) + + vector_instance.delete_by_ids.assert_called_once_with(["uid", "did"]) + vector_instance.add_texts.assert_called_once() + docs = vector_instance.add_texts.call_args.args[0] + assert len(docs) == 2 + + +def test_update_child_chunk_vector_economy_noop(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(indexing_technique="economy") + vector_cls = MagicMock() + monkeypatch.setattr(vector_service_module, "Vector", vector_cls) + VectorService.update_child_chunk_vector([], [], [], dataset) + vector_cls.assert_not_called() + + +def test_delete_child_chunk_vector_deletes_by_id(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset() + child_chunk = MagicMock() + child_chunk.index_node_id = "cid" + + vector_instance = MagicMock() + monkeypatch.setattr(vector_service_module, "Vector", MagicMock(return_value=vector_instance)) + + VectorService.delete_child_chunk_vector(child_chunk, dataset) + vector_instance.delete_by_ids.assert_called_once_with(["cid"]) + + +# --------------------------------------------------------------------------- +# update_multimodel_vector (missing coverage in previous suites) +# --------------------------------------------------------------------------- + + +def test_update_multimodel_vector_returns_when_not_high_quality(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(indexing_technique="economy", is_multimodal=True) + segment = _make_segment(tenant_id="t", attachments=[{"id": "a"}]) + + vector_cls = MagicMock() + db_mock = _mock_db_session_for_update_multimodel(upload_files=[]) + monkeypatch.setattr(vector_service_module, "Vector", vector_cls) + monkeypatch.setattr(vector_service_module, "db", db_mock) + + VectorService.update_multimodel_vector(segment=segment, attachment_ids=["a"], dataset=dataset) + vector_cls.assert_not_called() + db_mock.session.query.assert_not_called() + + +def test_update_multimodel_vector_returns_when_no_actual_change(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True) + segment = _make_segment(tenant_id="t", attachments=[{"id": "a"}, {"id": "b"}]) + + vector_cls = MagicMock() + db_mock = _mock_db_session_for_update_multimodel(upload_files=[]) + monkeypatch.setattr(vector_service_module, "Vector", vector_cls) + monkeypatch.setattr(vector_service_module, "db", db_mock) + + VectorService.update_multimodel_vector(segment=segment, attachment_ids=["b", "a"], dataset=dataset) + vector_cls.assert_not_called() + db_mock.session.query.assert_not_called() + + +def test_update_multimodel_vector_deletes_bindings_and_commits_on_empty_new_ids( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True) + segment = _make_segment(tenant_id="tenant-1", attachments=[{"id": "old-1"}, {"id": "old-2"}]) + + vector_instance = MagicMock(name="vector_instance") + vector_cls = MagicMock(return_value=vector_instance) + db_mock = _mock_db_session_for_update_multimodel(upload_files=[]) + + monkeypatch.setattr(vector_service_module, "Vector", vector_cls) + monkeypatch.setattr(vector_service_module, "db", db_mock) + + VectorService.update_multimodel_vector(segment=segment, attachment_ids=[], dataset=dataset) + + vector_cls.assert_called_once_with(dataset=dataset) + vector_instance.delete_by_ids.assert_called_once_with(["old-1", "old-2"]) + db_mock.session.query.assert_called_once_with(vector_service_module.SegmentAttachmentBinding) + db_mock.session.commit.assert_called_once() + db_mock.session.add_all.assert_not_called() + vector_instance.add_texts.assert_not_called() + + +def test_update_multimodel_vector_commits_when_no_upload_files_found(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True) + segment = _make_segment(tenant_id="tenant-1", attachments=[{"id": "old-1"}]) + + vector_instance = MagicMock() + monkeypatch.setattr(vector_service_module, "Vector", MagicMock(return_value=vector_instance)) + db_mock = _mock_db_session_for_update_multimodel(upload_files=[]) + monkeypatch.setattr(vector_service_module, "db", db_mock) + + VectorService.update_multimodel_vector(segment=segment, attachment_ids=["new-1"], dataset=dataset) + + db_mock.session.commit.assert_called_once() + db_mock.session.add_all.assert_not_called() + vector_instance.add_texts.assert_not_called() + + +def test_update_multimodel_vector_adds_bindings_and_vectors_and_skips_missing_upload_files( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True) + segment = _make_segment(segment_id="seg-1", tenant_id="tenant-1", attachments=[{"id": "old-1"}]) + + vector_instance = MagicMock() + monkeypatch.setattr(vector_service_module, "Vector", MagicMock(return_value=vector_instance)) + db_mock = _mock_db_session_for_update_multimodel(upload_files=[_UploadFileStub(id="file-1", name="img.png")]) + monkeypatch.setattr(vector_service_module, "db", db_mock) + + binding_ctor = MagicMock(side_effect=lambda **kwargs: kwargs) + monkeypatch.setattr(vector_service_module, "SegmentAttachmentBinding", binding_ctor) + + logger_mock = MagicMock() + monkeypatch.setattr(vector_service_module, "logger", logger_mock) + + VectorService.update_multimodel_vector(segment=segment, attachment_ids=["file-1", "missing"], dataset=dataset) + + logger_mock.warning.assert_called_once() + db_mock.session.add_all.assert_called_once() + bindings = db_mock.session.add_all.call_args.args[0] + assert len(bindings) == 1 + assert bindings[0]["attachment_id"] == "file-1" + + vector_instance.add_texts.assert_called_once() + documents = vector_instance.add_texts.call_args.args[0] + assert len(documents) == 1 + assert documents[0].page_content == "img.png" + assert documents[0].metadata["doc_id"] == "file-1" + db_mock.session.commit.assert_called_once() + + +def test_update_multimodel_vector_updates_bindings_without_multimodal_vector_ops( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=False) + segment = _make_segment(tenant_id="tenant-1", attachments=[{"id": "old-1"}]) + + vector_instance = MagicMock() + monkeypatch.setattr(vector_service_module, "Vector", MagicMock(return_value=vector_instance)) + db_mock = _mock_db_session_for_update_multimodel(upload_files=[_UploadFileStub(id="file-1", name="img.png")]) + monkeypatch.setattr(vector_service_module, "db", db_mock) + monkeypatch.setattr( + vector_service_module, "SegmentAttachmentBinding", MagicMock(side_effect=lambda **kwargs: kwargs) + ) + + VectorService.update_multimodel_vector(segment=segment, attachment_ids=["file-1"], dataset=dataset) + + vector_instance.delete_by_ids.assert_not_called() + vector_instance.add_texts.assert_not_called() + db_mock.session.add_all.assert_called_once() + db_mock.session.commit.assert_called_once() + + +def test_update_multimodel_vector_rolls_back_and_reraises_on_error(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True) + segment = _make_segment(segment_id="seg-1", tenant_id="tenant-1", attachments=[{"id": "old-1"}]) + + vector_instance = MagicMock() + monkeypatch.setattr(vector_service_module, "Vector", MagicMock(return_value=vector_instance)) + db_mock = _mock_db_session_for_update_multimodel(upload_files=[_UploadFileStub(id="file-1", name="img.png")]) + db_mock.session.commit.side_effect = RuntimeError("boom") + monkeypatch.setattr(vector_service_module, "db", db_mock) + monkeypatch.setattr( + vector_service_module, "SegmentAttachmentBinding", MagicMock(side_effect=lambda **kwargs: kwargs) + ) + + logger_mock = MagicMock() + monkeypatch.setattr(vector_service_module, "logger", logger_mock) + + with pytest.raises(RuntimeError, match="boom"): + VectorService.update_multimodel_vector(segment=segment, attachment_ids=["file-1"], dataset=dataset) + + logger_mock.exception.assert_called_once() + db_mock.session.rollback.assert_called_once() diff --git a/api/tests/unit_tests/services/test_website_service.py b/api/tests/unit_tests/services/test_website_service.py new file mode 100644 index 0000000000..38d94f4736 --- /dev/null +++ b/api/tests/unit_tests/services/test_website_service.py @@ -0,0 +1,718 @@ +"""Unit tests for services.website_service. + +Focuses on provider dispatching, argument validation, and provider-specific branches +without making any real network/storage/redis calls. +""" + +from __future__ import annotations + +import json +from dataclasses import dataclass +from datetime import UTC, datetime +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest + +import services.website_service as website_service_module +from services.website_service import ( + CrawlOptions, + WebsiteCrawlApiRequest, + WebsiteCrawlStatusApiRequest, + WebsiteService, +) + + +@dataclass(frozen=True) +class _DummyHttpxResponse: + payload: dict[str, Any] + + def json(self) -> dict[str, Any]: + return self.payload + + +@pytest.fixture(autouse=True) +def stub_current_user(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + website_service_module, + "current_user", + type("User", (), {"current_tenant_id": "tenant-1"})(), + ) + + +def test_crawl_options_include_exclude_paths() -> None: + options = CrawlOptions(includes="a,b", excludes="x,y") + assert options.get_include_paths() == ["a", "b"] + assert options.get_exclude_paths() == ["x", "y"] + + empty = CrawlOptions(includes=None, excludes=None) + assert empty.get_include_paths() == [] + assert empty.get_exclude_paths() == [] + + +def test_website_crawl_api_request_from_args_valid_and_to_crawl_request() -> None: + args = { + "provider": "firecrawl", + "url": "https://example.com", + "options": { + "limit": 2, + "crawl_sub_pages": True, + "only_main_content": True, + "includes": "a,b", + "excludes": "x", + "prompt": "hi", + "max_depth": 3, + "use_sitemap": False, + }, + } + + api_req = WebsiteCrawlApiRequest.from_args(args) + crawl_req = api_req.to_crawl_request() + + assert crawl_req.provider == "firecrawl" + assert crawl_req.url == "https://example.com" + assert crawl_req.options.limit == 2 + assert crawl_req.options.crawl_sub_pages is True + assert crawl_req.options.only_main_content is True + assert crawl_req.options.get_include_paths() == ["a", "b"] + assert crawl_req.options.get_exclude_paths() == ["x"] + assert crawl_req.options.prompt == "hi" + assert crawl_req.options.max_depth == 3 + assert crawl_req.options.use_sitemap is False + + +@pytest.mark.parametrize( + ("args", "missing_msg"), + [ + ({}, "Provider is required"), + ({"provider": "firecrawl"}, "URL is required"), + ({"provider": "firecrawl", "url": "https://example.com"}, "Options are required"), + ], +) +def test_website_crawl_api_request_from_args_requires_fields(args: dict, missing_msg: str) -> None: + with pytest.raises(ValueError, match=missing_msg): + WebsiteCrawlApiRequest.from_args(args) + + +def test_website_crawl_status_api_request_from_args_requires_fields() -> None: + with pytest.raises(ValueError, match="Provider is required"): + WebsiteCrawlStatusApiRequest.from_args({}, job_id="job-1") + + with pytest.raises(ValueError, match="Job ID is required"): + WebsiteCrawlStatusApiRequest.from_args({"provider": "firecrawl"}, job_id="") + + req = WebsiteCrawlStatusApiRequest.from_args({"provider": "firecrawl"}, job_id="job-1") + assert req.provider == "firecrawl" + assert req.job_id == "job-1" + + +def test_get_credentials_and_config_selects_plugin_id_and_key_firecrawl(monkeypatch: pytest.MonkeyPatch) -> None: + service_instance = MagicMock(name="DatasourceProviderService-instance") + service_instance.get_datasource_credentials.return_value = {"firecrawl_api_key": "k", "base_url": "b"} + monkeypatch.setattr(website_service_module, "DatasourceProviderService", MagicMock(return_value=service_instance)) + + api_key, config = WebsiteService._get_credentials_and_config("tenant-1", "firecrawl") + assert api_key == "k" + assert config["base_url"] == "b" + + service_instance.get_datasource_credentials.assert_called_once_with( + tenant_id="tenant-1", + provider="firecrawl", + plugin_id="langgenius/firecrawl_datasource", + ) + + +@pytest.mark.parametrize( + ("provider", "plugin_id"), + [ + ("watercrawl", "langgenius/watercrawl_datasource"), + ("jinareader", "langgenius/jina_datasource"), + ], +) +def test_get_credentials_and_config_selects_plugin_id_and_key_api_key( + monkeypatch: pytest.MonkeyPatch, provider: str, plugin_id: str +) -> None: + service_instance = MagicMock(name="DatasourceProviderService-instance") + service_instance.get_datasource_credentials.return_value = {"api_key": "enc-key", "base_url": "b"} + monkeypatch.setattr(website_service_module, "DatasourceProviderService", MagicMock(return_value=service_instance)) + + api_key, config = WebsiteService._get_credentials_and_config("tenant-1", provider) + assert api_key == "enc-key" + assert config["base_url"] == "b" + + service_instance.get_datasource_credentials.assert_called_once_with( + tenant_id="tenant-1", + provider=provider, + plugin_id=plugin_id, + ) + + +def test_get_credentials_and_config_rejects_invalid_provider() -> None: + with pytest.raises(ValueError, match="Invalid provider"): + WebsiteService._get_credentials_and_config("tenant-1", "unknown") + + +def test_get_credentials_and_config_hits_unreachable_guard_branch(monkeypatch: pytest.MonkeyPatch) -> None: + class FlakyProvider: + def __init__(self) -> None: + self._eq_calls = 0 + + def __hash__(self) -> int: + return 1 + + def __eq__(self, other: object) -> bool: + if other == "firecrawl": + self._eq_calls += 1 + return self._eq_calls == 1 + return False + + def __repr__(self) -> str: + return "FlakyProvider()" + + service_instance = MagicMock(name="DatasourceProviderService-instance") + service_instance.get_datasource_credentials.return_value = {"firecrawl_api_key": "k"} + monkeypatch.setattr(website_service_module, "DatasourceProviderService", MagicMock(return_value=service_instance)) + + with pytest.raises(ValueError, match="Invalid provider"): + WebsiteService._get_credentials_and_config("tenant-1", FlakyProvider()) # type: ignore[arg-type] + + +def test_get_decrypted_api_key_requires_api_key(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(website_service_module.encrypter, "decrypt_token", MagicMock()) + with pytest.raises(ValueError, match="API key not found in configuration"): + WebsiteService._get_decrypted_api_key("tenant-1", {}) + + +def test_get_decrypted_api_key_decrypts(monkeypatch: pytest.MonkeyPatch) -> None: + decrypt_mock = MagicMock(return_value="plain") + monkeypatch.setattr(website_service_module.encrypter, "decrypt_token", decrypt_mock) + + assert WebsiteService._get_decrypted_api_key("tenant-1", {"api_key": "enc"}) == "plain" + decrypt_mock.assert_called_once_with(tenant_id="tenant-1", token="enc") + + +def test_document_create_args_validate_wraps_error_message() -> None: + with pytest.raises(ValueError, match=r"^Invalid arguments: Provider is required$"): + WebsiteService.document_create_args_validate({}) + + +def test_crawl_url_dispatches_by_provider(monkeypatch: pytest.MonkeyPatch) -> None: + api_request = WebsiteCrawlApiRequest(provider="firecrawl", url="https://example.com", options={"limit": 1}) + crawl_request = api_request.to_crawl_request() + + monkeypatch.setattr(WebsiteService, "_get_credentials_and_config", MagicMock(return_value=("k", {"base_url": "b"}))) + firecrawl_mock = MagicMock(return_value={"status": "active", "job_id": "j1"}) + monkeypatch.setattr(WebsiteService, "_crawl_with_firecrawl", firecrawl_mock) + + result = WebsiteService.crawl_url(api_request) + + assert result == {"status": "active", "job_id": "j1"} + firecrawl_mock.assert_called_once() + assert firecrawl_mock.call_args.kwargs["request"] == crawl_request + + +@pytest.mark.parametrize( + ("provider", "method_name"), + [ + ("watercrawl", "_crawl_with_watercrawl"), + ("jinareader", "_crawl_with_jinareader"), + ], +) +def test_crawl_url_dispatches_other_providers(monkeypatch: pytest.MonkeyPatch, provider: str, method_name: str) -> None: + api_request = WebsiteCrawlApiRequest(provider=provider, url="https://example.com", options={"limit": 1}) + monkeypatch.setattr(WebsiteService, "_get_credentials_and_config", MagicMock(return_value=("k", {"base_url": "b"}))) + + impl_mock = MagicMock(return_value={"status": "active"}) + monkeypatch.setattr(WebsiteService, method_name, impl_mock) + + assert WebsiteService.crawl_url(api_request) == {"status": "active"} + impl_mock.assert_called_once() + + +def test_crawl_url_rejects_invalid_provider(monkeypatch: pytest.MonkeyPatch) -> None: + api_request = WebsiteCrawlApiRequest(provider="bad", url="https://example.com", options={"limit": 1}) + monkeypatch.setattr(WebsiteService, "_get_credentials_and_config", MagicMock(return_value=("k", {}))) + + with pytest.raises(ValueError, match="Invalid provider"): + WebsiteService.crawl_url(api_request) + + +def test_crawl_with_firecrawl_builds_params_single_page_and_sets_redis(monkeypatch: pytest.MonkeyPatch) -> None: + firecrawl_instance = MagicMock(name="FirecrawlApp-instance") + firecrawl_instance.crawl_url.return_value = "job-1" + firecrawl_cls = MagicMock(return_value=firecrawl_instance) + monkeypatch.setattr(website_service_module, "FirecrawlApp", firecrawl_cls) + + redis_mock = MagicMock() + monkeypatch.setattr(website_service_module, "redis_client", redis_mock) + + fixed_now = datetime(2024, 1, 1, tzinfo=UTC) + with patch.object(website_service_module.datetime, "datetime") as datetime_mock: + datetime_mock.now.return_value = fixed_now + + req = WebsiteCrawlApiRequest( + provider="firecrawl", url="https://example.com", options={"limit": 5} + ).to_crawl_request() + req.options.crawl_sub_pages = False + req.options.only_main_content = True + + result = WebsiteService._crawl_with_firecrawl(request=req, api_key="k", config={"base_url": "b"}) + + assert result == {"status": "active", "job_id": "job-1"} + + firecrawl_cls.assert_called_once_with(api_key="k", base_url="b") + firecrawl_instance.crawl_url.assert_called_once() + _, params = firecrawl_instance.crawl_url.call_args.args + assert params["limit"] == 1 + assert params["includePaths"] == [] + assert params["excludePaths"] == [] + assert params["scrapeOptions"] == {"onlyMainContent": True} + + redis_mock.setex.assert_called_once() + key, ttl, value = redis_mock.setex.call_args.args + assert key == "website_crawl_job-1" + assert ttl == 3600 + assert float(value) == pytest.approx(fixed_now.timestamp(), rel=0, abs=1e-6) + + +def test_crawl_with_firecrawl_builds_params_multi_page_including_prompt(monkeypatch: pytest.MonkeyPatch) -> None: + firecrawl_instance = MagicMock(name="FirecrawlApp-instance") + firecrawl_instance.crawl_url.return_value = "job-2" + monkeypatch.setattr(website_service_module, "FirecrawlApp", MagicMock(return_value=firecrawl_instance)) + monkeypatch.setattr(website_service_module, "redis_client", MagicMock()) + + req = WebsiteCrawlApiRequest( + provider="firecrawl", + url="https://example.com", + options={ + "crawl_sub_pages": True, + "limit": 3, + "only_main_content": False, + "includes": "a,b", + "excludes": "x", + "prompt": "use this", + }, + ).to_crawl_request() + + WebsiteService._crawl_with_firecrawl(request=req, api_key="k", config={"base_url": None}) + _, params = firecrawl_instance.crawl_url.call_args.args + assert params["includePaths"] == ["a", "b"] + assert params["excludePaths"] == ["x"] + assert params["limit"] == 3 + assert params["scrapeOptions"] == {"onlyMainContent": False} + assert params["prompt"] == "use this" + + +def test_crawl_with_watercrawl_passes_options_dict(monkeypatch: pytest.MonkeyPatch) -> None: + provider_instance = MagicMock() + provider_instance.crawl_url.return_value = {"status": "active", "job_id": "w1"} + provider_cls = MagicMock(return_value=provider_instance) + monkeypatch.setattr(website_service_module, "WaterCrawlProvider", provider_cls) + + req = WebsiteCrawlApiRequest( + provider="watercrawl", + url="https://example.com", + options={ + "limit": 2, + "crawl_sub_pages": True, + "only_main_content": True, + "includes": "a", + "excludes": None, + "max_depth": 5, + "use_sitemap": False, + }, + ).to_crawl_request() + + result = WebsiteService._crawl_with_watercrawl(request=req, api_key="k", config={"base_url": "b"}) + assert result == {"status": "active", "job_id": "w1"} + + provider_cls.assert_called_once_with(api_key="k", base_url="b") + provider_instance.crawl_url.assert_called_once_with( + url="https://example.com", + options={ + "limit": 2, + "crawl_sub_pages": True, + "only_main_content": True, + "includes": "a", + "excludes": None, + "max_depth": 5, + "use_sitemap": False, + }, + ) + + +def test_crawl_with_jinareader_single_page_success(monkeypatch: pytest.MonkeyPatch) -> None: + get_mock = MagicMock(return_value=_DummyHttpxResponse({"code": 200, "data": {"title": "t"}})) + monkeypatch.setattr(website_service_module.httpx, "get", get_mock) + + req = WebsiteCrawlApiRequest( + provider="jinareader", url="https://example.com", options={"crawl_sub_pages": False} + ).to_crawl_request() + req.options.crawl_sub_pages = False + + result = WebsiteService._crawl_with_jinareader(request=req, api_key="k") + assert result == {"status": "active", "data": {"title": "t"}} + get_mock.assert_called_once() + + +def test_crawl_with_jinareader_single_page_failure(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(website_service_module.httpx, "get", MagicMock(return_value=_DummyHttpxResponse({"code": 500}))) + req = WebsiteCrawlApiRequest( + provider="jinareader", url="https://example.com", options={"crawl_sub_pages": False} + ).to_crawl_request() + req.options.crawl_sub_pages = False + + with pytest.raises(ValueError, match="Failed to crawl:"): + WebsiteService._crawl_with_jinareader(request=req, api_key="k") + + +def test_crawl_with_jinareader_multi_page_success(monkeypatch: pytest.MonkeyPatch) -> None: + post_mock = MagicMock(return_value=_DummyHttpxResponse({"code": 200, "data": {"taskId": "t1"}})) + monkeypatch.setattr(website_service_module.httpx, "post", post_mock) + + req = WebsiteCrawlApiRequest( + provider="jinareader", + url="https://example.com", + options={"crawl_sub_pages": True, "limit": 5, "use_sitemap": True}, + ).to_crawl_request() + req.options.crawl_sub_pages = True + + result = WebsiteService._crawl_with_jinareader(request=req, api_key="k") + assert result == {"status": "active", "job_id": "t1"} + post_mock.assert_called_once() + + +def test_crawl_with_jinareader_multi_page_failure(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + website_service_module.httpx, "post", MagicMock(return_value=_DummyHttpxResponse({"code": 400})) + ) + req = WebsiteCrawlApiRequest( + provider="jinareader", + url="https://example.com", + options={"crawl_sub_pages": True, "limit": 2, "use_sitemap": False}, + ).to_crawl_request() + req.options.crawl_sub_pages = True + + with pytest.raises(ValueError, match="Failed to crawl$"): + WebsiteService._crawl_with_jinareader(request=req, api_key="k") + + +def test_get_crawl_status_dispatches(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(WebsiteService, "_get_credentials_and_config", MagicMock(return_value=("k", {"base_url": "b"}))) + firecrawl_status = MagicMock(return_value={"status": "active"}) + monkeypatch.setattr(WebsiteService, "_get_firecrawl_status", firecrawl_status) + + result = WebsiteService.get_crawl_status("job-1", "firecrawl") + assert result == {"status": "active"} + firecrawl_status.assert_called_once_with("job-1", "k", {"base_url": "b"}) + + watercrawl_status = MagicMock(return_value={"status": "active", "job_id": "w"}) + monkeypatch.setattr(WebsiteService, "_get_watercrawl_status", watercrawl_status) + assert WebsiteService.get_crawl_status("job-2", "watercrawl") == {"status": "active", "job_id": "w"} + watercrawl_status.assert_called_once_with("job-2", "k", {"base_url": "b"}) + + jinareader_status = MagicMock(return_value={"status": "active", "job_id": "j"}) + monkeypatch.setattr(WebsiteService, "_get_jinareader_status", jinareader_status) + assert WebsiteService.get_crawl_status("job-3", "jinareader") == {"status": "active", "job_id": "j"} + jinareader_status.assert_called_once_with("job-3", "k") + + +def test_get_crawl_status_typed_rejects_invalid_provider(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(WebsiteService, "_get_credentials_and_config", MagicMock(return_value=("k", {}))) + with pytest.raises(ValueError, match="Invalid provider"): + WebsiteService.get_crawl_status_typed(WebsiteCrawlStatusApiRequest(provider="bad", job_id="j")) + + +def test_get_firecrawl_status_adds_time_consuming_when_completed_and_cached(monkeypatch: pytest.MonkeyPatch) -> None: + firecrawl_instance = MagicMock() + firecrawl_instance.check_crawl_status.return_value = {"status": "completed", "total": 2, "current": 2, "data": []} + monkeypatch.setattr(website_service_module, "FirecrawlApp", MagicMock(return_value=firecrawl_instance)) + + redis_mock = MagicMock() + redis_mock.get.return_value = b"100.0" + monkeypatch.setattr(website_service_module, "redis_client", redis_mock) + + with patch.object(website_service_module.datetime, "datetime") as datetime_mock: + datetime_mock.now.return_value = datetime.fromtimestamp(105.0, tz=UTC) + result = WebsiteService._get_firecrawl_status(job_id="job-1", api_key="k", config={"base_url": "b"}) + + assert result["status"] == "completed" + assert result["time_consuming"] == "5.00" + redis_mock.delete.assert_called_once_with("website_crawl_job-1") + + +def test_get_firecrawl_status_completed_without_cache_does_not_add_time(monkeypatch: pytest.MonkeyPatch) -> None: + firecrawl_instance = MagicMock() + firecrawl_instance.check_crawl_status.return_value = {"status": "completed"} + monkeypatch.setattr(website_service_module, "FirecrawlApp", MagicMock(return_value=firecrawl_instance)) + + redis_mock = MagicMock() + redis_mock.get.return_value = None + monkeypatch.setattr(website_service_module, "redis_client", redis_mock) + + result = WebsiteService._get_firecrawl_status(job_id="job-1", api_key="k", config={"base_url": None}) + assert result["status"] == "completed" + assert "time_consuming" not in result + redis_mock.delete.assert_not_called() + + +def test_get_watercrawl_status_delegates(monkeypatch: pytest.MonkeyPatch) -> None: + provider_instance = MagicMock() + provider_instance.get_crawl_status.return_value = {"status": "active", "job_id": "w1"} + monkeypatch.setattr(website_service_module, "WaterCrawlProvider", MagicMock(return_value=provider_instance)) + + assert WebsiteService._get_watercrawl_status("job-1", "k", {"base_url": "b"}) == { + "status": "active", + "job_id": "w1", + } + provider_instance.get_crawl_status.assert_called_once_with("job-1") + + +def test_get_jinareader_status_active(monkeypatch: pytest.MonkeyPatch) -> None: + post_mock = MagicMock( + return_value=_DummyHttpxResponse( + { + "data": { + "status": "active", + "urls": ["a", "b"], + "processed": {"a": {}}, + "failed": {"b": {}}, + "duration": 3000, + } + } + ) + ) + monkeypatch.setattr(website_service_module.httpx, "post", post_mock) + + result = WebsiteService._get_jinareader_status("job-1", "k") + assert result["status"] == "active" + assert result["total"] == 2 + assert result["current"] == 2 + assert result["time_consuming"] == 3.0 + assert result["data"] == [] + post_mock.assert_called_once() + + +def test_get_jinareader_status_completed_formats_processed_items(monkeypatch: pytest.MonkeyPatch) -> None: + status_payload = { + "data": { + "status": "completed", + "urls": ["u1"], + "processed": {"u1": {}}, + "failed": {}, + "duration": 1000, + } + } + processed_payload = { + "data": { + "processed": { + "u1": { + "data": { + "title": "t", + "url": "u1", + "description": "d", + "content": "md", + } + } + } + } + } + post_mock = MagicMock(side_effect=[_DummyHttpxResponse(status_payload), _DummyHttpxResponse(processed_payload)]) + monkeypatch.setattr(website_service_module.httpx, "post", post_mock) + + result = WebsiteService._get_jinareader_status("job-1", "k") + assert result["status"] == "completed" + assert result["data"] == [{"title": "t", "source_url": "u1", "description": "d", "markdown": "md"}] + assert post_mock.call_count == 2 + + +def test_get_crawl_url_data_dispatches_invalid_provider() -> None: + with pytest.raises(ValueError, match="Invalid provider"): + WebsiteService.get_crawl_url_data("job-1", "bad", "https://example.com", "tenant-1") + + +def test_get_crawl_url_data_hits_invalid_provider_branch_when_credentials_stubbed( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr(WebsiteService, "_get_credentials_and_config", MagicMock(return_value=("k", {}))) + with pytest.raises(ValueError, match="Invalid provider"): + WebsiteService.get_crawl_url_data("job-1", object(), "u", "tenant-1") # type: ignore[arg-type] + + +@pytest.mark.parametrize( + ("provider", "method_name"), + [ + ("firecrawl", "_get_firecrawl_url_data"), + ("watercrawl", "_get_watercrawl_url_data"), + ("jinareader", "_get_jinareader_url_data"), + ], +) +def test_get_crawl_url_data_dispatches(monkeypatch: pytest.MonkeyPatch, provider: str, method_name: str) -> None: + monkeypatch.setattr(WebsiteService, "_get_credentials_and_config", MagicMock(return_value=("k", {"base_url": "b"}))) + impl_mock = MagicMock(return_value={"ok": True}) + monkeypatch.setattr(WebsiteService, method_name, impl_mock) + + result = WebsiteService.get_crawl_url_data("job-1", provider, "u", "tenant-1") + assert result == {"ok": True} + impl_mock.assert_called_once() + + +def test_get_firecrawl_url_data_reads_from_storage_when_present(monkeypatch: pytest.MonkeyPatch) -> None: + stored_list = [{"source_url": "https://example.com", "title": "t"}] + stored = json.dumps(stored_list).encode("utf-8") + + storage_mock = MagicMock() + storage_mock.exists.return_value = True + storage_mock.load_once.return_value = stored + monkeypatch.setattr(website_service_module, "storage", storage_mock) + + monkeypatch.setattr(website_service_module, "FirecrawlApp", MagicMock()) + + result = WebsiteService._get_firecrawl_url_data("job-1", "https://example.com", "k", {"base_url": "b"}) + assert result == {"source_url": "https://example.com", "title": "t"} + assert result is not stored_list[0] + + +def test_get_firecrawl_url_data_returns_none_when_storage_empty(monkeypatch: pytest.MonkeyPatch) -> None: + storage_mock = MagicMock() + storage_mock.exists.return_value = True + storage_mock.load_once.return_value = b"" + monkeypatch.setattr(website_service_module, "storage", storage_mock) + + assert WebsiteService._get_firecrawl_url_data("job-1", "https://example.com", "k", {}) is None + + +def test_get_firecrawl_url_data_raises_when_job_not_completed(monkeypatch: pytest.MonkeyPatch) -> None: + storage_mock = MagicMock() + storage_mock.exists.return_value = False + monkeypatch.setattr(website_service_module, "storage", storage_mock) + + firecrawl_instance = MagicMock() + firecrawl_instance.check_crawl_status.return_value = {"status": "active"} + monkeypatch.setattr(website_service_module, "FirecrawlApp", MagicMock(return_value=firecrawl_instance)) + + with pytest.raises(ValueError, match="Crawl job is not completed"): + WebsiteService._get_firecrawl_url_data("job-1", "https://example.com", "k", {"base_url": None}) + + +def test_get_firecrawl_url_data_returns_none_when_not_found(monkeypatch: pytest.MonkeyPatch) -> None: + storage_mock = MagicMock() + storage_mock.exists.return_value = False + monkeypatch.setattr(website_service_module, "storage", storage_mock) + + firecrawl_instance = MagicMock() + firecrawl_instance.check_crawl_status.return_value = {"status": "completed", "data": [{"source_url": "x"}]} + monkeypatch.setattr(website_service_module, "FirecrawlApp", MagicMock(return_value=firecrawl_instance)) + + assert WebsiteService._get_firecrawl_url_data("job-1", "https://example.com", "k", {"base_url": "b"}) is None + + +def test_get_watercrawl_url_data_delegates(monkeypatch: pytest.MonkeyPatch) -> None: + provider_instance = MagicMock() + provider_instance.get_crawl_url_data.return_value = {"source_url": "u"} + monkeypatch.setattr(website_service_module, "WaterCrawlProvider", MagicMock(return_value=provider_instance)) + + result = WebsiteService._get_watercrawl_url_data("job-1", "u", "k", {"base_url": "b"}) + assert result == {"source_url": "u"} + provider_instance.get_crawl_url_data.assert_called_once_with("job-1", "u") + + +def test_get_jinareader_url_data_without_job_id_success(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + website_service_module.httpx, + "get", + MagicMock(return_value=_DummyHttpxResponse({"code": 200, "data": {"url": "u"}})), + ) + assert WebsiteService._get_jinareader_url_data("", "u", "k") == {"url": "u"} + + +def test_get_jinareader_url_data_without_job_id_failure(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(website_service_module.httpx, "get", MagicMock(return_value=_DummyHttpxResponse({"code": 500}))) + with pytest.raises(ValueError, match="Failed to crawl$"): + WebsiteService._get_jinareader_url_data("", "u", "k") + + +def test_get_jinareader_url_data_with_job_id_completed_returns_matching_item(monkeypatch: pytest.MonkeyPatch) -> None: + status_payload = {"data": {"status": "completed", "processed": {"u1": {}}}} + processed_payload = {"data": {"processed": {"u1": {"data": {"url": "u", "title": "t"}}}}} + + post_mock = MagicMock(side_effect=[_DummyHttpxResponse(status_payload), _DummyHttpxResponse(processed_payload)]) + monkeypatch.setattr(website_service_module.httpx, "post", post_mock) + + assert WebsiteService._get_jinareader_url_data("job-1", "u", "k") == {"url": "u", "title": "t"} + assert post_mock.call_count == 2 + + +def test_get_jinareader_url_data_with_job_id_not_completed_raises(monkeypatch: pytest.MonkeyPatch) -> None: + post_mock = MagicMock(return_value=_DummyHttpxResponse({"data": {"status": "active"}})) + monkeypatch.setattr(website_service_module.httpx, "post", post_mock) + + with pytest.raises(ValueError, match=r"Crawl job is no\s*t completed"): + WebsiteService._get_jinareader_url_data("job-1", "u", "k") + + +def test_get_jinareader_url_data_with_job_id_completed_but_not_found_returns_none( + monkeypatch: pytest.MonkeyPatch, +) -> None: + status_payload = {"data": {"status": "completed", "processed": {"u1": {}}}} + processed_payload = {"data": {"processed": {"u1": {"data": {"url": "other"}}}}} + + post_mock = MagicMock(side_effect=[_DummyHttpxResponse(status_payload), _DummyHttpxResponse(processed_payload)]) + monkeypatch.setattr(website_service_module.httpx, "post", post_mock) + + assert WebsiteService._get_jinareader_url_data("job-1", "u", "k") is None + + +def test_get_scrape_url_data_dispatches_and_rejects_invalid_provider(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(WebsiteService, "_get_credentials_and_config", MagicMock(return_value=("k", {"base_url": "b"}))) + + scrape_mock = MagicMock(return_value={"data": "x"}) + monkeypatch.setattr(WebsiteService, "_scrape_with_firecrawl", scrape_mock) + assert WebsiteService.get_scrape_url_data("firecrawl", "u", "tenant-1", True) == {"data": "x"} + scrape_mock.assert_called_once() + + watercrawl_mock = MagicMock(return_value={"data": "y"}) + monkeypatch.setattr(WebsiteService, "_scrape_with_watercrawl", watercrawl_mock) + assert WebsiteService.get_scrape_url_data("watercrawl", "u", "tenant-1", False) == {"data": "y"} + watercrawl_mock.assert_called_once() + + with pytest.raises(ValueError, match="Invalid provider"): + WebsiteService.get_scrape_url_data("jinareader", "u", "tenant-1", True) + + +def test_scrape_with_firecrawl_calls_app(monkeypatch: pytest.MonkeyPatch) -> None: + firecrawl_instance = MagicMock() + firecrawl_instance.scrape_url.return_value = {"markdown": "m"} + monkeypatch.setattr(website_service_module, "FirecrawlApp", MagicMock(return_value=firecrawl_instance)) + + result = WebsiteService._scrape_with_firecrawl( + request=website_service_module.ScrapeRequest( + provider="firecrawl", + url="u", + tenant_id="tenant-1", + only_main_content=True, + ), + api_key="k", + config={"base_url": "b"}, + ) + assert result == {"markdown": "m"} + firecrawl_instance.scrape_url.assert_called_once_with(url="u", params={"onlyMainContent": True}) + + +def test_scrape_with_watercrawl_calls_provider(monkeypatch: pytest.MonkeyPatch) -> None: + provider_instance = MagicMock() + provider_instance.scrape_url.return_value = {"markdown": "m"} + monkeypatch.setattr(website_service_module, "WaterCrawlProvider", MagicMock(return_value=provider_instance)) + + result = WebsiteService._scrape_with_watercrawl( + request=website_service_module.ScrapeRequest( + provider="watercrawl", + url="u", + tenant_id="tenant-1", + only_main_content=False, + ), + api_key="k", + config={"base_url": "b"}, + ) + assert result == {"markdown": "m"} + provider_instance.scrape_url.assert_called_once_with("u")