From 157208ab1e01ddc530a79982bd2ce2ddb5eb720e Mon Sep 17 00:00:00 2001 From: mahammadasim <135003320+mahammadasim@users.noreply.github.com> Date: Thu, 12 Mar 2026 13:04:20 +0530 Subject: [PATCH 1/3] =?UTF-8?q?test:=20added=20test=20for=20services=20of?= =?UTF-8?q?=20ops,=20summary,=20vector,=20website=20and=20ji=E2=80=A6=20(#?= =?UTF-8?q?32893)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: akashseth-ifp --- .../auth/test_jina_auth_standalone_module.py | 157 ++ .../unit_tests/services/test_ops_service.py | 381 +++++ .../services/test_summary_index_service.py | 1329 +++++++++++++++++ .../services/test_vector_service.py | 704 +++++++++ .../services/test_website_service.py | 718 +++++++++ 5 files changed, 3289 insertions(+) create mode 100644 api/tests/unit_tests/services/auth/test_jina_auth_standalone_module.py create mode 100644 api/tests/unit_tests/services/test_ops_service.py create mode 100644 api/tests/unit_tests/services/test_summary_index_service.py create mode 100644 api/tests/unit_tests/services/test_vector_service.py create mode 100644 api/tests/unit_tests/services/test_website_service.py diff --git a/api/tests/unit_tests/services/auth/test_jina_auth_standalone_module.py b/api/tests/unit_tests/services/auth/test_jina_auth_standalone_module.py new file mode 100644 index 0000000000..c2fcd71875 --- /dev/null +++ b/api/tests/unit_tests/services/auth/test_jina_auth_standalone_module.py @@ -0,0 +1,157 @@ +from __future__ import annotations + +import importlib.util +import sys +from pathlib import Path +from types import ModuleType +from unittest.mock import MagicMock + +import httpx +import pytest + + +@pytest.fixture(scope="module") +def jina_module() -> ModuleType: + """ + Load `api/services/auth/jina.py` as a standalone module. + + This repo contains both `services/auth/jina.py` and a package at + `services/auth/jina/`, so importing `services.auth.jina` can be ambiguous. + """ + + module_path = Path(__file__).resolve().parents[4] / "services" / "auth" / "jina.py" + # Use a stable module name so pytest-cov can target it with `--cov=services.auth.jina_file`. 
+ spec = importlib.util.spec_from_file_location("services.auth.jina_file", module_path) + assert spec is not None + assert spec.loader is not None + module = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = module + spec.loader.exec_module(module) + return module + + +def _credentials(api_key: str | None = "test_api_key_123", auth_type: str = "bearer") -> dict: + config: dict = {} if api_key is None else {"api_key": api_key} + return {"auth_type": auth_type, "config": config} + + +def test_init_valid_bearer_credentials(jina_module: ModuleType) -> None: + auth = jina_module.JinaAuth(_credentials()) + assert auth.api_key == "test_api_key_123" + assert auth.credentials["auth_type"] == "bearer" + + +def test_init_rejects_invalid_auth_type(jina_module: ModuleType) -> None: + with pytest.raises(ValueError, match="Invalid auth type.*Bearer"): + jina_module.JinaAuth(_credentials(auth_type="basic")) + + +@pytest.mark.parametrize("credentials", [{"auth_type": "bearer", "config": {}}, {"auth_type": "bearer"}]) +def test_init_requires_api_key(jina_module: ModuleType, credentials: dict) -> None: + with pytest.raises(ValueError, match="No API key provided"): + jina_module.JinaAuth(credentials) + + +def test_prepare_headers_includes_bearer_api_key(jina_module: ModuleType) -> None: + auth = jina_module.JinaAuth(_credentials(api_key="k")) + assert auth._prepare_headers() == {"Content-Type": "application/json", "Authorization": "Bearer k"} + + +def test_post_request_calls_httpx(jina_module: ModuleType, monkeypatch: pytest.MonkeyPatch) -> None: + auth = jina_module.JinaAuth(_credentials(api_key="k")) + post_mock = MagicMock(name="httpx.post") + monkeypatch.setattr(jina_module.httpx, "post", post_mock) + + auth._post_request("https://r.jina.ai", {"url": "https://example.com"}, {"h": "v"}) + post_mock.assert_called_once_with("https://r.jina.ai", headers={"h": "v"}, json={"url": "https://example.com"}) + + +def test_validate_credentials_success(jina_module: ModuleType, monkeypatch: pytest.MonkeyPatch) -> None: + auth = jina_module.JinaAuth(_credentials(api_key="k")) + + response = MagicMock() + response.status_code = 200 + post_mock = MagicMock(return_value=response) + monkeypatch.setattr(jina_module.httpx, "post", post_mock) + + assert auth.validate_credentials() is True + post_mock.assert_called_once_with( + "https://r.jina.ai", + headers={"Content-Type": "application/json", "Authorization": "Bearer k"}, + json={"url": "https://example.com"}, + ) + + +def test_validate_credentials_non_200_raises_via_handle_error( + jina_module: ModuleType, monkeypatch: pytest.MonkeyPatch +) -> None: + auth = jina_module.JinaAuth(_credentials(api_key="k")) + + response = MagicMock() + response.status_code = 402 + response.json.return_value = {"error": "Payment required"} + monkeypatch.setattr(jina_module.httpx, "post", MagicMock(return_value=response)) + + with pytest.raises(Exception, match="Status code: 402.*Payment required"): + auth.validate_credentials() + + +@pytest.mark.parametrize("status_code", [402, 409, 500]) +def test_handle_error_statuses_use_response_json(jina_module: ModuleType, status_code: int) -> None: + auth = jina_module.JinaAuth(_credentials(api_key="k")) + response = MagicMock() + response.status_code = status_code + response.json.return_value = {"error": "boom"} + + with pytest.raises(Exception, match=f"Status code: {status_code}.*boom"): + auth._handle_error(response) + + +def test_handle_error_statuses_default_unknown_error(jina_module: ModuleType) -> None: + auth = 
jina_module.JinaAuth(_credentials(api_key="k")) + response = MagicMock() + response.status_code = 402 + response.json.return_value = {} + + with pytest.raises(Exception, match="Unknown error occurred"): + auth._handle_error(response) + + +def test_handle_error_with_text_json_body(jina_module: ModuleType) -> None: + auth = jina_module.JinaAuth(_credentials(api_key="k")) + response = MagicMock() + response.status_code = 403 + response.text = '{"error": "Forbidden"}' + + with pytest.raises(Exception, match="Status code: 403.*Forbidden"): + auth._handle_error(response) + + +def test_handle_error_with_text_json_body_missing_error(jina_module: ModuleType) -> None: + auth = jina_module.JinaAuth(_credentials(api_key="k")) + response = MagicMock() + response.status_code = 403 + response.text = "{}" + + with pytest.raises(Exception, match="Unknown error occurred"): + auth._handle_error(response) + + +def test_handle_error_without_text_raises_unexpected(jina_module: ModuleType) -> None: + auth = jina_module.JinaAuth(_credentials(api_key="k")) + response = MagicMock() + response.status_code = 404 + response.text = "" + + with pytest.raises(Exception, match="Unexpected error occurred.*404"): + auth._handle_error(response) + + +def test_validate_credentials_propagates_network_errors( + jina_module: ModuleType, monkeypatch: pytest.MonkeyPatch +) -> None: + auth = jina_module.JinaAuth(_credentials(api_key="k")) + monkeypatch.setattr(jina_module.httpx, "post", MagicMock(side_effect=httpx.ConnectError("boom"))) + + with pytest.raises(httpx.ConnectError, match="boom"): + auth.validate_credentials() diff --git a/api/tests/unit_tests/services/test_ops_service.py b/api/tests/unit_tests/services/test_ops_service.py new file mode 100644 index 0000000000..ab7b473790 --- /dev/null +++ b/api/tests/unit_tests/services/test_ops_service.py @@ -0,0 +1,381 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from core.ops.entities.config_entity import TracingProviderEnum +from models.model import App, TraceAppConfig +from services.ops_service import OpsService + + +class TestOpsService: + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_get_tracing_app_config_no_config(self, mock_ops_trace_manager, mock_db): + # Arrange + mock_db.session.query.return_value.where.return_value.first.return_value = None + + # Act + result = OpsService.get_tracing_app_config("app_id", "arize") + + # Assert + assert result is None + mock_db.session.query.assert_called_with(TraceAppConfig) + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_get_tracing_app_config_no_app(self, mock_ops_trace_manager, mock_db): + # Arrange + trace_config = MagicMock(spec=TraceAppConfig) + mock_db.session.query.return_value.where.return_value.first.side_effect = [trace_config, None] + + # Act + result = OpsService.get_tracing_app_config("app_id", "arize") + + # Assert + assert result is None + assert mock_db.session.query.call_count == 2 + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_get_tracing_app_config_none_config(self, mock_ops_trace_manager, mock_db): + # Arrange + trace_config = MagicMock(spec=TraceAppConfig) + trace_config.tracing_config = None + app = MagicMock(spec=App) + app.tenant_id = "tenant_id" + mock_db.session.query.return_value.where.return_value.first.side_effect = [trace_config, app] + + # Act & Assert + with pytest.raises(ValueError, match="Tracing config cannot be None."): + 
OpsService.get_tracing_app_config("app_id", "arize") + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + @pytest.mark.parametrize( + ("provider", "default_url"), + [ + ("arize", "https://app.arize.com/"), + ("phoenix", "https://app.phoenix.arize.com/projects/"), + ("langsmith", "https://smith.langchain.com/"), + ("opik", "https://www.comet.com/opik/"), + ("weave", "https://wandb.ai/"), + ("aliyun", "https://arms.console.aliyun.com/"), + ("tencent", "https://console.cloud.tencent.com/apm"), + ("mlflow", "http://localhost:5000/"), + ("databricks", "https://www.databricks.com/"), + ], + ) + def test_get_tracing_app_config_providers_exception(self, mock_ops_trace_manager, mock_db, provider, default_url): + # Arrange + trace_config = MagicMock(spec=TraceAppConfig) + trace_config.tracing_config = {"some": "config"} + trace_config.to_dict.return_value = {"tracing_config": {"project_url": default_url}} + app = MagicMock(spec=App) + app.tenant_id = "tenant_id" + mock_db.session.query.return_value.where.return_value.first.side_effect = [trace_config, app] + + mock_ops_trace_manager.decrypt_tracing_config.return_value = {} + mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {} + mock_ops_trace_manager.get_trace_config_project_url.side_effect = Exception("error") + mock_ops_trace_manager.get_trace_config_project_key.side_effect = Exception("error") + + # Act + result = OpsService.get_tracing_app_config("app_id", provider) + + # Assert + assert result["tracing_config"]["project_url"] == default_url + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + @pytest.mark.parametrize( + "provider", ["arize", "phoenix", "langsmith", "opik", "weave", "aliyun", "tencent", "mlflow", "databricks"] + ) + def test_get_tracing_app_config_providers_success(self, mock_ops_trace_manager, mock_db, provider): + # Arrange + trace_config = MagicMock(spec=TraceAppConfig) + trace_config.tracing_config = {"some": "config"} + trace_config.to_dict.return_value = {"tracing_config": {"project_url": "success_url"}} + app = MagicMock(spec=App) + app.tenant_id = "tenant_id" + mock_db.session.query.return_value.where.return_value.first.side_effect = [trace_config, app] + + mock_ops_trace_manager.decrypt_tracing_config.return_value = {} + mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {} + mock_ops_trace_manager.get_trace_config_project_url.return_value = "success_url" + + # Act + result = OpsService.get_tracing_app_config("app_id", provider) + + # Assert + assert result["tracing_config"]["project_url"] == "success_url" + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_get_tracing_app_config_langfuse_success(self, mock_ops_trace_manager, mock_db): + # Arrange + trace_config = MagicMock(spec=TraceAppConfig) + trace_config.tracing_config = {"some": "config"} + trace_config.to_dict.return_value = {"tracing_config": {"project_url": "https://api.langfuse.com/project/key"}} + app = MagicMock(spec=App) + app.tenant_id = "tenant_id" + mock_db.session.query.return_value.where.return_value.first.side_effect = [trace_config, app] + + mock_ops_trace_manager.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"} + mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"} + mock_ops_trace_manager.get_trace_config_project_key.return_value = "key" + + # Act + result = OpsService.get_tracing_app_config("app_id", "langfuse") + + # Assert + assert 
result["tracing_config"]["project_url"] == "https://api.langfuse.com/project/key" + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_get_tracing_app_config_langfuse_exception(self, mock_ops_trace_manager, mock_db): + # Arrange + trace_config = MagicMock(spec=TraceAppConfig) + trace_config.tracing_config = {"some": "config"} + trace_config.to_dict.return_value = {"tracing_config": {"project_url": "https://api.langfuse.com/"}} + app = MagicMock(spec=App) + app.tenant_id = "tenant_id" + mock_db.session.query.return_value.where.return_value.first.side_effect = [trace_config, app] + + mock_ops_trace_manager.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"} + mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"} + mock_ops_trace_manager.get_trace_config_project_key.side_effect = Exception("error") + + # Act + result = OpsService.get_tracing_app_config("app_id", "langfuse") + + # Assert + assert result["tracing_config"]["project_url"] == "https://api.langfuse.com/" + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_create_tracing_app_config_invalid_provider(self, mock_ops_trace_manager, mock_db): + # Act + result = OpsService.create_tracing_app_config("app_id", "invalid_provider", {}) + + # Assert + assert result == {"error": "Invalid tracing provider: invalid_provider"} + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_create_tracing_app_config_invalid_credentials(self, mock_ops_trace_manager, mock_db): + # Arrange + provider = TracingProviderEnum.LANGFUSE + mock_ops_trace_manager.check_trace_config_is_effective.return_value = False + + # Act + result = OpsService.create_tracing_app_config("app_id", provider, {"public_key": "p", "secret_key": "s"}) + + # Assert + assert result == {"error": "Invalid Credentials"} + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + @pytest.mark.parametrize( + ("provider", "config"), + [ + (TracingProviderEnum.ARIZE, {}), + (TracingProviderEnum.LANGFUSE, {"public_key": "p", "secret_key": "s"}), + (TracingProviderEnum.LANGSMITH, {"api_key": "k", "project": "p"}), + (TracingProviderEnum.ALIYUN, {"license_key": "k", "endpoint": "https://aliyun.com"}), + ], + ) + def test_create_tracing_app_config_project_url_exception(self, mock_ops_trace_manager, mock_db, provider, config): + # Arrange + mock_ops_trace_manager.check_trace_config_is_effective.return_value = True + mock_ops_trace_manager.get_trace_config_project_url.side_effect = Exception("error") + mock_ops_trace_manager.get_trace_config_project_key.side_effect = Exception("error") + mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock(spec=TraceAppConfig) + + # Act + result = OpsService.create_tracing_app_config("app_id", provider, config) + + # Assert + assert result is None + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_create_tracing_app_config_langfuse_success(self, mock_ops_trace_manager, mock_db): + # Arrange + provider = TracingProviderEnum.LANGFUSE + mock_ops_trace_manager.check_trace_config_is_effective.return_value = True + mock_ops_trace_manager.get_trace_config_project_key.return_value = "key" + app = MagicMock(spec=App) + app.tenant_id = "tenant_id" + mock_db.session.query.return_value.where.return_value.first.side_effect = [None, app] + 
mock_ops_trace_manager.encrypt_tracing_config.return_value = {} + + # Act + result = OpsService.create_tracing_app_config( + "app_id", provider, {"public_key": "p", "secret_key": "s", "host": "https://api.langfuse.com"} + ) + + # Assert + assert result == {"result": "success"} + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_create_tracing_app_config_already_exists(self, mock_ops_trace_manager, mock_db): + # Arrange + provider = TracingProviderEnum.ARIZE + mock_ops_trace_manager.check_trace_config_is_effective.return_value = True + mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock(spec=TraceAppConfig) + + # Act + result = OpsService.create_tracing_app_config("app_id", provider, {}) + + # Assert + assert result is None + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_create_tracing_app_config_no_app(self, mock_ops_trace_manager, mock_db): + # Arrange + provider = TracingProviderEnum.ARIZE + mock_ops_trace_manager.check_trace_config_is_effective.return_value = True + mock_db.session.query.return_value.where.return_value.first.side_effect = [None, None] + + # Act + result = OpsService.create_tracing_app_config("app_id", provider, {}) + + # Assert + assert result is None + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_create_tracing_app_config_with_empty_other_keys(self, mock_ops_trace_manager, mock_db): + # Arrange + provider = TracingProviderEnum.ARIZE + mock_ops_trace_manager.check_trace_config_is_effective.return_value = True + app = MagicMock(spec=App) + app.tenant_id = "tenant_id" + mock_db.session.query.return_value.where.return_value.first.side_effect = [None, app] + mock_ops_trace_manager.encrypt_tracing_config.return_value = {} + + # Act + # 'project' is in other_keys for Arize + # provide an empty string for the project in the tracing_config + # create_tracing_app_config will replace it with the default from the model + result = OpsService.create_tracing_app_config("app_id", provider, {"project": ""}) + + # Assert + assert result == {"result": "success"} + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_create_tracing_app_config_success(self, mock_ops_trace_manager, mock_db): + # Arrange + provider = TracingProviderEnum.ARIZE + mock_ops_trace_manager.check_trace_config_is_effective.return_value = True + mock_ops_trace_manager.get_trace_config_project_url.return_value = "http://project_url" + app = MagicMock(spec=App) + app.tenant_id = "tenant_id" + mock_db.session.query.return_value.where.return_value.first.side_effect = [None, app] + mock_ops_trace_manager.encrypt_tracing_config.return_value = {"encrypted": "config"} + + # Act + result = OpsService.create_tracing_app_config("app_id", provider, {}) + + # Assert + assert result == {"result": "success"} + mock_db.session.add.assert_called() + mock_db.session.commit.assert_called() + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_update_tracing_app_config_invalid_provider(self, mock_ops_trace_manager, mock_db): + # Act & Assert + with pytest.raises(ValueError, match="Invalid tracing provider: invalid_provider"): + OpsService.update_tracing_app_config("app_id", "invalid_provider", {}) + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_update_tracing_app_config_no_config(self, mock_ops_trace_manager, mock_db): + # 
Arrange + provider = TracingProviderEnum.ARIZE + mock_db.session.query.return_value.where.return_value.first.return_value = None + + # Act + result = OpsService.update_tracing_app_config("app_id", provider, {}) + + # Assert + assert result is None + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_update_tracing_app_config_no_app(self, mock_ops_trace_manager, mock_db): + # Arrange + provider = TracingProviderEnum.ARIZE + current_config = MagicMock(spec=TraceAppConfig) + mock_db.session.query.return_value.where.return_value.first.side_effect = [current_config, None] + + # Act + result = OpsService.update_tracing_app_config("app_id", provider, {}) + + # Assert + assert result is None + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_update_tracing_app_config_invalid_credentials(self, mock_ops_trace_manager, mock_db): + # Arrange + provider = TracingProviderEnum.ARIZE + current_config = MagicMock(spec=TraceAppConfig) + app = MagicMock(spec=App) + app.tenant_id = "tenant_id" + mock_db.session.query.return_value.where.return_value.first.side_effect = [current_config, app] + mock_ops_trace_manager.decrypt_tracing_config.return_value = {} + mock_ops_trace_manager.check_trace_config_is_effective.return_value = False + + # Act & Assert + with pytest.raises(ValueError, match="Invalid Credentials"): + OpsService.update_tracing_app_config("app_id", provider, {}) + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_update_tracing_app_config_success(self, mock_ops_trace_manager, mock_db): + # Arrange + provider = TracingProviderEnum.ARIZE + current_config = MagicMock(spec=TraceAppConfig) + current_config.to_dict.return_value = {"some": "data"} + app = MagicMock(spec=App) + app.tenant_id = "tenant_id" + mock_db.session.query.return_value.where.return_value.first.side_effect = [current_config, app] + mock_ops_trace_manager.decrypt_tracing_config.return_value = {} + mock_ops_trace_manager.check_trace_config_is_effective.return_value = True + + # Act + result = OpsService.update_tracing_app_config("app_id", provider, {}) + + # Assert + assert result == {"some": "data"} + mock_db.session.commit.assert_called_once() + + @patch("services.ops_service.db") + def test_delete_tracing_app_config_no_config(self, mock_db): + # Arrange + mock_db.session.query.return_value.where.return_value.first.return_value = None + + # Act + result = OpsService.delete_tracing_app_config("app_id", "arize") + + # Assert + assert result is None + + @patch("services.ops_service.db") + def test_delete_tracing_app_config_success(self, mock_db): + # Arrange + trace_config = MagicMock(spec=TraceAppConfig) + mock_db.session.query.return_value.where.return_value.first.return_value = trace_config + + # Act + result = OpsService.delete_tracing_app_config("app_id", "arize") + + # Assert + assert result is True + mock_db.session.delete.assert_called_with(trace_config) + mock_db.session.commit.assert_called_once() diff --git a/api/tests/unit_tests/services/test_summary_index_service.py b/api/tests/unit_tests/services/test_summary_index_service.py new file mode 100644 index 0000000000..c7e1fed21f --- /dev/null +++ b/api/tests/unit_tests/services/test_summary_index_service.py @@ -0,0 +1,1329 @@ +"""Unit tests for services.summary_index_service.""" + +from __future__ import annotations + +import sys +from dataclasses import dataclass +from datetime import UTC, datetime +from types import SimpleNamespace +from 
unittest.mock import MagicMock + +import pytest + +import services.summary_index_service as summary_module +from services.summary_index_service import SummaryIndexService + + +@dataclass(frozen=True) +class _SessionContext: + session: MagicMock + + def __enter__(self) -> MagicMock: + return self.session + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + +def _dataset(*, indexing_technique: str = "high_quality") -> MagicMock: + dataset = MagicMock(name="dataset") + dataset.id = "dataset-1" + dataset.tenant_id = "tenant-1" + dataset.indexing_technique = indexing_technique + dataset.embedding_model_provider = "openai" + dataset.embedding_model = "text-embedding" + return dataset + + +def _segment(*, has_document: bool = True) -> MagicMock: + segment = MagicMock(name="segment") + segment.id = "seg-1" + segment.document_id = "doc-1" + segment.dataset_id = "dataset-1" + segment.content = "hello world" + segment.enabled = True + segment.status = "completed" + segment.position = 1 + if has_document: + doc = MagicMock(name="document") + doc.doc_language = "en" + doc.doc_form = "text_model" + segment.document = doc + else: + segment.document = None + return segment + + +def _summary_record(*, summary_content: str = "summary", node_id: str | None = None) -> MagicMock: + record = MagicMock(spec=summary_module.DocumentSegmentSummary, name="summary_record") + record.id = "sum-1" + record.dataset_id = "dataset-1" + record.document_id = "doc-1" + record.chunk_id = "seg-1" + record.summary_content = summary_content + record.summary_index_node_id = node_id + record.summary_index_node_hash = None + record.tokens = None + record.status = "generating" + record.error = None + record.enabled = True + record.created_at = datetime(2024, 1, 1, tzinfo=UTC) + record.updated_at = datetime(2024, 1, 1, tzinfo=UTC) + record.disabled_at = None + record.disabled_by = None + return record + + +def test_generate_summary_for_segment_passes_document_language(monkeypatch: pytest.MonkeyPatch) -> None: + usage = MagicMock() + usage.total_tokens = 10 + usage.prompt_tokens = 3 + usage.completion_tokens = 7 + + paragraph_module = SimpleNamespace( + ParagraphIndexProcessor=SimpleNamespace(generate_summary=MagicMock(return_value=("sum", usage))) + ) + monkeypatch.setitem( + sys.modules, + "core.rag.index_processor.processor.paragraph_index_processor", + paragraph_module, + ) + + segment = _segment(has_document=True) + dataset = _dataset() + + content, got_usage = SummaryIndexService.generate_summary_for_segment(segment, dataset, {"a": 1}) + assert content == "sum" + assert got_usage is usage + + paragraph_module.ParagraphIndexProcessor.generate_summary.assert_called_once() + _, kwargs = paragraph_module.ParagraphIndexProcessor.generate_summary.call_args + assert kwargs["document_language"] == "en" + + +def test_generate_summary_for_segment_raises_when_empty(monkeypatch: pytest.MonkeyPatch) -> None: + paragraph_module = SimpleNamespace( + ParagraphIndexProcessor=SimpleNamespace(generate_summary=MagicMock(return_value=("", MagicMock()))) + ) + monkeypatch.setitem( + sys.modules, + "core.rag.index_processor.processor.paragraph_index_processor", + paragraph_module, + ) + + with pytest.raises(ValueError, match="Generated summary is empty"): + SummaryIndexService.generate_summary_for_segment(_segment(), _dataset(), {"a": 1}) + + +def test_create_summary_record_updates_existing_and_reenables(monkeypatch: pytest.MonkeyPatch) -> None: + existing = _summary_record(summary_content="old", node_id="n1") + existing.enabled = False 
+ existing.disabled_at = datetime(2024, 1, 1) + existing.disabled_by = "u" + + session = MagicMock(name="session") + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = existing + session.query.return_value = query + + create_session_mock = MagicMock(return_value=_SessionContext(session)) + monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock)) + + segment = _segment() + dataset = _dataset() + + result = SummaryIndexService.create_summary_record(segment, dataset, "new", status="generating") + assert result is existing + assert existing.summary_content == "new" + assert existing.status == "generating" + assert existing.enabled is True + assert existing.disabled_at is None + assert existing.disabled_by is None + assert existing.error is None + session.add.assert_called_once_with(existing) + session.flush.assert_called_once() + + +def test_create_summary_record_creates_new(monkeypatch: pytest.MonkeyPatch) -> None: + session = MagicMock(name="session") + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = None + session.query.return_value = query + + create_session_mock = MagicMock(return_value=_SessionContext(session)) + monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock)) + + record = SummaryIndexService.create_summary_record(_segment(), _dataset(), "new", status="generating") + assert record.dataset_id == "dataset-1" + assert record.chunk_id == "seg-1" + assert record.summary_content == "new" + assert record.enabled is True + session.add.assert_called_once() + session.flush.assert_called_once() + + +def test_vectorize_summary_skips_non_high_quality(monkeypatch: pytest.MonkeyPatch) -> None: + vector_cls = MagicMock() + monkeypatch.setattr(summary_module, "Vector", vector_cls) + SummaryIndexService.vectorize_summary(_summary_record(), _segment(), _dataset(indexing_technique="economy")) + vector_cls.assert_not_called() + + +def test_vectorize_summary_raises_for_blank_content() -> None: + with pytest.raises(ValueError, match="Summary content is empty"): + SummaryIndexService.vectorize_summary(_summary_record(summary_content=" "), _segment(), _dataset()) + + +def test_vectorize_summary_retries_connection_errors_then_succeeds(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + summary = _summary_record(summary_content="sum", node_id=None) + + monkeypatch.setattr(summary_module.uuid, "uuid4", MagicMock(return_value="uuid-1")) + monkeypatch.setattr(summary_module.helper, "generate_text_hash", MagicMock(return_value="hash-1")) + + embedding_model = MagicMock() + embedding_model.get_text_embedding_num_tokens.return_value = [5] + model_manager = MagicMock() + model_manager.get_model_instance.return_value = embedding_model + monkeypatch.setattr(summary_module, "ModelManager", MagicMock(return_value=model_manager)) + + vector_instance = MagicMock() + vector_instance.add_texts.side_effect = [RuntimeError("connection timeout"), None] + monkeypatch.setattr(summary_module, "Vector", MagicMock(return_value=vector_instance)) + + session = MagicMock(name="provided_session") + merged = _summary_record(summary_content="sum") + session.merge.return_value = merged + monkeypatch.setattr(summary_module.time, "sleep", MagicMock()) + + SummaryIndexService.vectorize_summary(summary, segment, dataset, session=session) + + assert vector_instance.add_texts.call_count == 2 + 
summary_module.time.sleep.assert_called_once() # type: ignore[attr-defined] + session.flush.assert_called_once() + assert summary.status == "completed" + assert summary.summary_index_node_id == "uuid-1" + assert summary.summary_index_node_hash == "hash-1" + assert summary.tokens == 5 + + +def test_vectorize_summary_without_session_creates_record_when_missing(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + summary = _summary_record(summary_content="sum", node_id="old-node") + + monkeypatch.setattr(summary_module.helper, "generate_text_hash", MagicMock(return_value="hash-1")) + + # Force deletion branch to run and swallow delete failures. + vector_for_delete = MagicMock() + vector_for_delete.delete_by_ids.side_effect = RuntimeError("delete failed") + vector_for_add = MagicMock() + vector_for_add.add_texts.return_value = None + vector_cls = MagicMock(side_effect=[vector_for_delete, vector_for_add]) + monkeypatch.setattr(summary_module, "Vector", vector_cls) + + model_manager = MagicMock() + model_manager.get_model_instance.side_effect = RuntimeError("no model") + monkeypatch.setattr(summary_module, "ModelManager", MagicMock(return_value=model_manager)) + + # New session used after vectorization succeeds (record not found by id nor chunk_id). + session = MagicMock(name="session") + q1 = MagicMock() + q1.filter_by.return_value = q1 + q1.first.side_effect = [None, None] + session.query.return_value = q1 + + create_session_mock = MagicMock(return_value=_SessionContext(session)) + monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock)) + + SummaryIndexService.vectorize_summary(summary, segment, dataset, session=None) + + # One context for success path, no error handler session. 
+ create_session_mock.assert_called() + session.add.assert_called() + session.commit.assert_called_once() + assert summary.status == "completed" + assert summary.summary_index_node_id == "old-node" # reused + + +def test_vectorize_summary_final_failure_updates_error_status(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + summary = _summary_record(summary_content="sum", node_id=None) + + monkeypatch.setattr(summary_module.uuid, "uuid4", MagicMock(return_value="uuid-1")) + monkeypatch.setattr(summary_module.helper, "generate_text_hash", MagicMock(return_value="hash-1")) + monkeypatch.setattr(summary_module.time, "sleep", MagicMock()) + + vector_instance = MagicMock() + vector_instance.add_texts.side_effect = RuntimeError("boom") + monkeypatch.setattr(summary_module, "Vector", MagicMock(return_value=vector_instance)) + + # error_session should find record and commit status update + error_session = MagicMock(name="error_session") + q = MagicMock() + q.filter_by.return_value = q + q.first.return_value = summary + error_session.query.return_value = q + + create_session_mock = MagicMock(return_value=_SessionContext(error_session)) + monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock)) + + with pytest.raises(RuntimeError, match="boom"): + SummaryIndexService.vectorize_summary(summary, segment, dataset, session=None) + + assert summary.status == "error" + assert "Vectorization failed" in (summary.error or "") + error_session.commit.assert_called_once() + + +def test_batch_create_summary_records_no_segments_noop(monkeypatch: pytest.MonkeyPatch) -> None: + create_session_mock = MagicMock() + monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock)) + SummaryIndexService.batch_create_summary_records([], _dataset()) + create_session_mock.assert_not_called() + + +def test_batch_create_summary_records_creates_and_updates(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + s1 = _segment() + s2 = _segment() + s2.id = "seg-2" + s2.document_id = "doc-2" + + existing = _summary_record() + existing.chunk_id = "seg-2" + existing.enabled = False + + session = MagicMock() + query = MagicMock() + query.filter.return_value = query + query.all.return_value = [existing] + session.query.return_value = query + + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + SummaryIndexService.batch_create_summary_records([s1, s2], dataset, status="not_started") + session.commit.assert_called_once() + assert existing.enabled is True + + +def test_update_summary_record_error_updates_when_exists(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + record = _summary_record() + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = record + session.query.return_value = query + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + SummaryIndexService.update_summary_record_error(segment, dataset, "err") + assert record.status == "error" + assert record.error == "err" + session.commit.assert_called_once() + + +def test_generate_and_vectorize_summary_success(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + record = _summary_record(summary_content="") 
+ + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = record + session.query.return_value = query + + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + monkeypatch.setattr( + SummaryIndexService, "generate_summary_for_segment", MagicMock(return_value=("sum", MagicMock(total_tokens=0))) + ) + monkeypatch.setattr(SummaryIndexService, "vectorize_summary", MagicMock(return_value=None)) + + out = SummaryIndexService.generate_and_vectorize_summary(segment, dataset, {"enable": True}) + assert out is record + session.refresh.assert_called_once_with(record) + session.commit.assert_called() + + +def test_generate_and_vectorize_summary_vectorize_failure_sets_error(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + record = _summary_record(summary_content="") + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = record + session.query.return_value = query + + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + monkeypatch.setattr( + SummaryIndexService, "generate_summary_for_segment", MagicMock(return_value=("sum", MagicMock(total_tokens=0))) + ) + monkeypatch.setattr(SummaryIndexService, "vectorize_summary", MagicMock(side_effect=RuntimeError("boom"))) + + with pytest.raises(RuntimeError, match="boom"): + SummaryIndexService.generate_and_vectorize_summary(segment, dataset, {"enable": True}) + assert record.status == "error" + # Outer exception handler overwrites the error with the raw exception message. + assert record.error == "boom" + + +def test_vectorize_summary_updates_existing_record_found_by_chunk_id(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + summary = _summary_record(summary_content="sum", node_id=None) + + monkeypatch.setattr(summary_module.uuid, "uuid4", MagicMock(return_value="uuid-1")) + monkeypatch.setattr(summary_module.helper, "generate_text_hash", MagicMock(return_value="hash-1")) + + vector_instance = MagicMock() + vector_instance.add_texts.return_value = None + monkeypatch.setattr(summary_module, "Vector", MagicMock(return_value=vector_instance)) + monkeypatch.setattr( + summary_module, + "ModelManager", + MagicMock(return_value=MagicMock(get_model_instance=MagicMock(return_value=None))), + ) + + existing = _summary_record(summary_content="old", node_id="old-node") + existing.id = "other-id" + session = MagicMock(name="session") + q = MagicMock() + q.filter_by.return_value = q + q.first.side_effect = [None, existing] # miss by id, hit by chunk_id + session.query.return_value = q + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + SummaryIndexService.vectorize_summary(summary, segment, dataset, session=None) + session.commit.assert_called_once() + assert existing.summary_index_node_id == "uuid-1" + + +def test_vectorize_summary_updates_existing_record_found_by_id(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + summary = _summary_record(summary_content="sum", node_id=None) + + monkeypatch.setattr(summary_module.uuid, "uuid4", MagicMock(return_value="uuid-1")) + monkeypatch.setattr(summary_module.helper, "generate_text_hash", 
MagicMock(return_value="hash-1")) + monkeypatch.setattr( + summary_module, "Vector", MagicMock(return_value=MagicMock(add_texts=MagicMock(return_value=None))) + ) + monkeypatch.setattr( + summary_module, + "ModelManager", + MagicMock(return_value=MagicMock(get_model_instance=MagicMock(return_value=None))), + ) + + existing = _summary_record(summary_content="old", node_id="old-node") + session = MagicMock(name="session") + q = MagicMock() + q.filter_by.return_value = q + q.first.return_value = existing # hit by id + session.query.return_value = q + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + SummaryIndexService.vectorize_summary(summary, segment, dataset, session=None) + session.commit.assert_called_once() + assert existing.summary_index_node_hash == "hash-1" + + +def test_vectorize_summary_session_enter_returns_none_triggers_runtime_error(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + summary = _summary_record(summary_content="sum", node_id=None) + + monkeypatch.setattr(summary_module.uuid, "uuid4", MagicMock(return_value="uuid-1")) + monkeypatch.setattr(summary_module.helper, "generate_text_hash", MagicMock(return_value="hash-1")) + monkeypatch.setattr( + summary_module, "Vector", MagicMock(return_value=MagicMock(add_texts=MagicMock(return_value=None))) + ) + monkeypatch.setattr( + summary_module, + "ModelManager", + MagicMock(return_value=MagicMock(get_model_instance=MagicMock(return_value=None))), + ) + + class _BadContext: + def __enter__(self): + return None + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + error_session = MagicMock() + q = MagicMock() + q.filter_by.return_value = q + q.first.return_value = summary + error_session.query.return_value = q + + create_session_mock = MagicMock(side_effect=[_BadContext(), _SessionContext(error_session)]) + monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock)) + + with pytest.raises(RuntimeError, match="Session should not be None"): + SummaryIndexService.vectorize_summary(summary, segment, dataset, session=None) + + +def test_vectorize_summary_created_record_becomes_none_triggers_guard(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + summary = _summary_record(summary_content="sum", node_id=None) + + monkeypatch.setattr(summary_module.uuid, "uuid4", MagicMock(return_value="uuid-1")) + monkeypatch.setattr(summary_module.helper, "generate_text_hash", MagicMock(return_value="hash-1")) + monkeypatch.setattr( + summary_module, "Vector", MagicMock(return_value=MagicMock(add_texts=MagicMock(return_value=None))) + ) + monkeypatch.setattr( + summary_module, + "ModelManager", + MagicMock(return_value=MagicMock(get_model_instance=MagicMock(return_value=None))), + ) + + session = MagicMock() + q = MagicMock() + q.filter_by.return_value = q + q.first.side_effect = [None, None] # miss by id and chunk_id + session.query.return_value = q + + error_session = MagicMock() + eq = MagicMock() + eq.filter_by.return_value = eq + eq.first.return_value = summary + error_session.query.return_value = eq + + create_session_mock = MagicMock(side_effect=[_SessionContext(session), _SessionContext(error_session)]) + monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock)) + + # Force the created record to be None so the "should not be None" guard triggers. 
+ monkeypatch.setattr(summary_module, "DocumentSegmentSummary", MagicMock(return_value=None)) + + with pytest.raises(RuntimeError, match="summary_record_in_session should not be None"): + SummaryIndexService.vectorize_summary(summary, segment, dataset, session=None) + + +def test_vectorize_summary_error_handler_tries_chunk_id_lookup_and_can_warn_not_found( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dataset = _dataset() + segment = _segment() + summary = _summary_record(summary_content="sum", node_id=None) + + monkeypatch.setattr(summary_module.uuid, "uuid4", MagicMock(return_value="uuid-1")) + monkeypatch.setattr(summary_module.helper, "generate_text_hash", MagicMock(return_value="hash-1")) + monkeypatch.setattr(summary_module.time, "sleep", MagicMock()) + monkeypatch.setattr( + summary_module, + "Vector", + MagicMock(return_value=MagicMock(add_texts=MagicMock(side_effect=RuntimeError("boom")))), + ) + + error_session = MagicMock(name="error_session") + q = MagicMock() + q.filter_by.return_value = q + q.first.side_effect = [None, None] # not found by id, not found by chunk_id + error_session.query.return_value = q + + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(error_session))), + ) + + with pytest.raises(RuntimeError, match="boom"): + SummaryIndexService.vectorize_summary(summary, segment, dataset, session=None) + + # No record -> no commit in error session. + error_session.commit.assert_not_called() + + +def test_update_summary_record_error_warns_when_missing(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = None + session.query.return_value = query + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + logger_mock = MagicMock() + monkeypatch.setattr(summary_module, "logger", logger_mock) + + SummaryIndexService.update_summary_record_error(segment, dataset, "err") + logger_mock.warning.assert_called_once() + + +def test_generate_and_vectorize_summary_creates_missing_record_and_logs_usage(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = None + session.query.return_value = query + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + usage = MagicMock(total_tokens=4, prompt_tokens=1, completion_tokens=3) + monkeypatch.setattr(SummaryIndexService, "generate_summary_for_segment", MagicMock(return_value=("sum", usage))) + monkeypatch.setattr(SummaryIndexService, "vectorize_summary", MagicMock(return_value=None)) + + logger_mock = MagicMock() + monkeypatch.setattr(summary_module, "logger", logger_mock) + + result = SummaryIndexService.generate_and_vectorize_summary(segment, dataset, {"enable": True}) + assert result.status in {"generating", "completed"} + logger_mock.info.assert_called() + + +def test_generate_summaries_for_document_skip_conditions(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset(indexing_technique="economy") + document = MagicMock(spec=summary_module.DatasetDocument) + document.id = "doc-1" + document.doc_form = "text_model" + assert 
SummaryIndexService.generate_summaries_for_document(dataset, document, {"enable": True}) == [] + + dataset = _dataset() + assert SummaryIndexService.generate_summaries_for_document(dataset, document, {"enable": False}) == [] + + document.doc_form = "qa_model" + assert SummaryIndexService.generate_summaries_for_document(dataset, document, {"enable": True}) == [] + + +def test_generate_summaries_for_document_runs_and_handles_errors(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + document = MagicMock(spec=summary_module.DatasetDocument) + document.id = "doc-1" + document.doc_form = "text_model" + + seg1 = _segment() + seg2 = _segment() + seg2.id = "seg-2" + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.filter.return_value = query + query.all.return_value = [seg1, seg2] + session.query.return_value = query + + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + monkeypatch.setattr(SummaryIndexService, "batch_create_summary_records", MagicMock()) + monkeypatch.setattr( + SummaryIndexService, + "generate_and_vectorize_summary", + MagicMock(side_effect=[MagicMock(), RuntimeError("boom")]), + ) + update_err_mock = MagicMock() + monkeypatch.setattr(SummaryIndexService, "update_summary_record_error", update_err_mock) + + records = SummaryIndexService.generate_summaries_for_document(dataset, document, {"enable": True}) + assert len(records) == 1 + update_err_mock.assert_called_once() + + +def test_generate_summaries_for_document_no_segments_returns_empty(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + document = MagicMock(spec=summary_module.DatasetDocument) + document.id = "doc-1" + document.doc_form = "text_model" + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.filter.return_value = query + query.all.return_value = [] + session.query.return_value = query + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + assert SummaryIndexService.generate_summaries_for_document(dataset, document, {"enable": True}) == [] + + +def test_generate_summaries_for_document_applies_segment_ids_and_only_parent_chunks( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dataset = _dataset() + document = MagicMock(spec=summary_module.DatasetDocument) + document.id = "doc-1" + document.doc_form = "text_model" + seg = _segment() + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.filter.return_value = query + query.all.return_value = [seg] + session.query.return_value = query + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + monkeypatch.setattr(SummaryIndexService, "batch_create_summary_records", MagicMock()) + monkeypatch.setattr(SummaryIndexService, "generate_and_vectorize_summary", MagicMock(return_value=MagicMock())) + + SummaryIndexService.generate_summaries_for_document( + dataset, + document, + {"enable": True}, + segment_ids=[seg.id], + only_parent_chunks=True, + ) + query.filter.assert_called() + + +def test_disable_summaries_for_segments_handles_vector_delete_error(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + summary1 = _summary_record(summary_content="s", node_id="n1") + summary2 = _summary_record(summary_content="s", 
node_id=None) + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.filter.return_value = query + query.all.return_value = [summary1, summary2] + session.query.return_value = query + + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + monkeypatch.setattr( + summary_module, + "Vector", + MagicMock(return_value=MagicMock(delete_by_ids=MagicMock(side_effect=RuntimeError("boom")))), + ) + monkeypatch.setitem( + sys.modules, "libs.datetime_utils", SimpleNamespace(naive_utc_now=MagicMock(return_value=datetime(2024, 1, 1))) + ) + + SummaryIndexService.disable_summaries_for_segments(dataset, segment_ids=["seg-1"], disabled_by="u") + assert summary1.enabled is False + assert summary1.disabled_by == "u" + session.commit.assert_called_once() + + +def test_disable_summaries_for_segments_no_summaries_noop(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.filter.return_value = query + query.all.return_value = [] + session.query.return_value = query + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + monkeypatch.setitem( + sys.modules, "libs.datetime_utils", SimpleNamespace(naive_utc_now=MagicMock(return_value=datetime(2024, 1, 1))) + ) + SummaryIndexService.disable_summaries_for_segments(dataset) + session.commit.assert_not_called() + + +def test_enable_summaries_for_segments_skips_non_high_quality() -> None: + SummaryIndexService.enable_summaries_for_segments(_dataset(indexing_technique="economy")) + + +def test_enable_summaries_for_segments_revectorizes_and_enables(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + summary = _summary_record(summary_content="sum", node_id="n1") + summary.enabled = False + + segment = _segment() + segment.id = summary.chunk_id + segment.enabled = True + segment.status = "completed" + + session = MagicMock() + summary_query = MagicMock() + summary_query.filter_by.return_value = summary_query + summary_query.filter.return_value = summary_query + summary_query.all.return_value = [summary] + + seg_query = MagicMock() + seg_query.filter_by.return_value = seg_query + seg_query.first.return_value = segment + + def query_side_effect(model: object) -> MagicMock: + if model is summary_module.DocumentSegmentSummary: + return summary_query + return seg_query + + session.query.side_effect = query_side_effect + + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + vec_mock = MagicMock() + monkeypatch.setattr(SummaryIndexService, "vectorize_summary", vec_mock) + + SummaryIndexService.enable_summaries_for_segments(dataset, segment_ids=[summary.chunk_id]) + vec_mock.assert_called_once() + assert summary.enabled is True + session.commit.assert_called_once() + + +def test_enable_summaries_for_segments_no_summaries_noop(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.filter.return_value = query + query.all.return_value = [] + session.query.return_value = query + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + 
SummaryIndexService.enable_summaries_for_segments(dataset) + session.commit.assert_not_called() + + +def test_enable_summaries_for_segments_skips_segment_or_content_and_handles_vectorize_error( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dataset = _dataset() + summary1 = _summary_record(summary_content="sum", node_id="n1") + summary1.enabled = False + summary2 = _summary_record(summary_content="", node_id="n2") + summary2.enabled = False + summary3 = _summary_record(summary_content="sum3", node_id="n3") + summary3.enabled = False + + bad_segment = _segment() + bad_segment.enabled = False + bad_segment.status = "completed" + + good_segment = _segment() + good_segment.enabled = True + good_segment.status = "completed" + + session = MagicMock() + summary_query = MagicMock() + summary_query.filter_by.return_value = summary_query + summary_query.filter.return_value = summary_query + summary_query.all.return_value = [summary1, summary2, summary3] + + seg_query = MagicMock() + seg_query.filter_by.return_value = seg_query + seg_query.first.side_effect = [bad_segment, good_segment, good_segment] + + def query_side_effect(model: object) -> MagicMock: + if model is summary_module.DocumentSegmentSummary: + return summary_query + return seg_query + + session.query.side_effect = query_side_effect + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + logger_mock = MagicMock() + monkeypatch.setattr(summary_module, "logger", logger_mock) + monkeypatch.setattr(SummaryIndexService, "vectorize_summary", MagicMock(side_effect=RuntimeError("boom"))) + + SummaryIndexService.enable_summaries_for_segments(dataset) + logger_mock.exception.assert_called_once() + session.commit.assert_called_once() + + +def test_delete_summaries_for_segments_deletes_vectors_and_records(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + summary = _summary_record(summary_content="sum", node_id="n1") + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.filter.return_value = query + query.all.return_value = [summary] + session.query.return_value = query + + vector_instance = MagicMock() + monkeypatch.setattr(summary_module, "Vector", MagicMock(return_value=vector_instance)) + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids=[summary.chunk_id]) + vector_instance.delete_by_ids.assert_called_once_with(["n1"]) + session.delete.assert_called_once_with(summary) + session.commit.assert_called_once() + + +def test_delete_summaries_for_segments_no_summaries_noop(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.filter.return_value = query + query.all.return_value = [] + session.query.return_value = query + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + SummaryIndexService.delete_summaries_for_segments(dataset) + session.commit.assert_not_called() + + +def test_update_summary_for_segment_skip_conditions() -> None: + assert ( + SummaryIndexService.update_summary_for_segment(_segment(), _dataset(indexing_technique="economy"), "x") is None + ) + seg = _segment(has_document=True) + seg.document.doc_form = 
"qa_model" + assert SummaryIndexService.update_summary_for_segment(seg, _dataset(), "x") is None + + +def test_update_summary_for_segment_empty_content_deletes_existing(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + record = _summary_record(summary_content="old", node_id="n1") + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = record + session.query.return_value = query + + vector_instance = MagicMock() + monkeypatch.setattr(summary_module, "Vector", MagicMock(return_value=vector_instance)) + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + assert SummaryIndexService.update_summary_for_segment(segment, dataset, " ") is None + vector_instance.delete_by_ids.assert_called_once_with(["n1"]) + session.delete.assert_called_once_with(record) + session.commit.assert_called_once() + + +def test_update_summary_for_segment_empty_content_delete_vector_warns(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + record = _summary_record(summary_content="old", node_id="n1") + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = record + session.query.return_value = query + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + vector_instance = MagicMock() + vector_instance.delete_by_ids.side_effect = RuntimeError("boom") + monkeypatch.setattr(summary_module, "Vector", MagicMock(return_value=vector_instance)) + logger_mock = MagicMock() + monkeypatch.setattr(summary_module, "logger", logger_mock) + + assert SummaryIndexService.update_summary_for_segment(segment, dataset, "") is None + logger_mock.warning.assert_called() + + +def test_update_summary_for_segment_empty_content_no_record_noop(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = None + session.query.return_value = query + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + assert SummaryIndexService.update_summary_for_segment(segment, dataset, " ") is None + + +def test_update_summary_for_segment_updates_existing_and_vectorizes(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + record = _summary_record(summary_content="old", node_id="n1") + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = record + session.query.return_value = query + + vector_instance = MagicMock() + monkeypatch.setattr(summary_module, "Vector", MagicMock(return_value=vector_instance)) + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + vectorize_mock = MagicMock() + monkeypatch.setattr(SummaryIndexService, "vectorize_summary", vectorize_mock) + + out = SummaryIndexService.update_summary_for_segment(segment, dataset, "new summary") + assert out is record + vectorize_mock.assert_called_once() + session.refresh.assert_called_once_with(record) + session.commit.assert_called() + + +def 
test_update_summary_for_segment_existing_vector_delete_warns(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + record = _summary_record(summary_content="old", node_id="n1") + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = record + session.query.return_value = query + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + vector_instance = MagicMock() + vector_instance.delete_by_ids.side_effect = RuntimeError("boom") + monkeypatch.setattr(summary_module, "Vector", MagicMock(return_value=vector_instance)) + monkeypatch.setattr(SummaryIndexService, "vectorize_summary", MagicMock(return_value=None)) + logger_mock = MagicMock() + monkeypatch.setattr(summary_module, "logger", logger_mock) + + SummaryIndexService.update_summary_for_segment(segment, dataset, "new") + logger_mock.warning.assert_called() + + +def test_update_summary_for_segment_existing_vectorize_failure_returns_error_record( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dataset = _dataset() + segment = _segment() + record = _summary_record(summary_content="old", node_id="n1") + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = record + session.query.return_value = query + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + monkeypatch.setattr(SummaryIndexService, "vectorize_summary", MagicMock(side_effect=RuntimeError("boom"))) + + out = SummaryIndexService.update_summary_for_segment(segment, dataset, "new") + assert out is record + assert out.status == "error" + assert "Vectorization failed" in (out.error or "") + + +def test_update_summary_for_segment_new_record_success(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = None + session.query.return_value = query + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + created = _summary_record(summary_content="new", node_id=None) + monkeypatch.setattr(SummaryIndexService, "create_summary_record", MagicMock(return_value=created)) + session.merge.return_value = created + monkeypatch.setattr(SummaryIndexService, "vectorize_summary", MagicMock(return_value=None)) + + out = SummaryIndexService.update_summary_for_segment(segment, dataset, "new") + assert out is created + session.refresh.assert_called() + session.commit.assert_called() + + +def test_update_summary_for_segment_outer_exception_sets_error_and_reraises(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + record = _summary_record(summary_content="old", node_id="n1") + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = record + session.query.return_value = query + session.flush.side_effect = RuntimeError("flush boom") + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + with pytest.raises(RuntimeError, match="flush boom"): + SummaryIndexService.update_summary_for_segment(segment, dataset, "new") + assert record.status == 
"error" + assert record.error == "flush boom" + session.commit.assert_called() + + +def test_get_segment_summary_and_document_summaries(monkeypatch: pytest.MonkeyPatch) -> None: + record = _summary_record(summary_content="sum", node_id="n1") + session = MagicMock() + + q1 = MagicMock() + q1.where.return_value = q1 + q1.first.return_value = record + + q2 = MagicMock() + q2.filter.return_value = q2 + q2.all.return_value = [record] + + def query_side_effect(model: object) -> MagicMock: + if model is summary_module.DocumentSegmentSummary: + # first call used by get_segment_summary, second by get_document_summaries + if not hasattr(query_side_effect, "_called"): + query_side_effect._called = True # type: ignore[attr-defined] + return q1 + return q2 + return MagicMock() + + session.query.side_effect = query_side_effect + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + assert SummaryIndexService.get_segment_summary("seg-1", "dataset-1") is record + assert SummaryIndexService.get_document_summaries("doc-1", "dataset-1", segment_ids=["seg-1"]) == [record] + + +def test_get_segments_summaries_non_empty(monkeypatch: pytest.MonkeyPatch) -> None: + record1 = _summary_record() + record1.chunk_id = "seg-1" + record2 = _summary_record() + record2.chunk_id = "seg-2" + session = MagicMock() + q = MagicMock() + q.where.return_value = q + q.all.return_value = [record1, record2] + session.query.return_value = q + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + out = SummaryIndexService.get_segments_summaries(["seg-1", "seg-2"], "dataset-1") + assert set(out.keys()) == {"seg-1", "seg-2"} + + +def test_get_document_summary_index_status_no_segments_returns_none(monkeypatch: pytest.MonkeyPatch) -> None: + session = MagicMock() + q = MagicMock() + q.where.return_value = q + q.all.return_value = [] + session.query.return_value = q + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + assert SummaryIndexService.get_document_summary_index_status("doc-1", "dataset-1", "tenant-1") is None + + +def test_get_documents_summary_index_status_empty_input(monkeypatch: pytest.MonkeyPatch) -> None: + assert SummaryIndexService.get_documents_summary_index_status([], "dataset-1", "tenant-1") == {} + + +def test_get_documents_summary_index_status_no_pending_sets_none(monkeypatch: pytest.MonkeyPatch) -> None: + session = MagicMock() + q = MagicMock() + q.where.return_value = q + q.all.return_value = [SimpleNamespace(id="seg-1", document_id="doc-1")] + session.query.return_value = q + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + monkeypatch.setattr( + SummaryIndexService, + "get_segments_summaries", + MagicMock(return_value={"seg-1": SimpleNamespace(status="completed")}), + ) + result = SummaryIndexService.get_documents_summary_index_status(["doc-1"], "dataset-1", "tenant-1") + assert result["doc-1"] is None + + +def test_update_summary_for_segment_creates_new_and_vectorize_fails_returns_error_record( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dataset = _dataset() + segment = _segment() + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = None + 
session.query.return_value = query + + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + created = _summary_record(summary_content="new", node_id=None) + monkeypatch.setattr(SummaryIndexService, "create_summary_record", MagicMock(return_value=created)) + session.merge.return_value = created + + vectorize_mock = MagicMock(side_effect=RuntimeError("boom")) + monkeypatch.setattr(SummaryIndexService, "vectorize_summary", vectorize_mock) + + out = SummaryIndexService.update_summary_for_segment(segment, dataset, "new") + assert out.status == "error" + assert "Vectorization failed" in (out.error or "") + + +def test_get_segments_summaries_empty_list() -> None: + assert SummaryIndexService.get_segments_summaries([], "dataset-1") == {} + + +def test_get_document_summary_index_status_and_documents_status(monkeypatch: pytest.MonkeyPatch) -> None: + seg_row = SimpleNamespace(id="seg-1", document_id="doc-1") + session = MagicMock() + query = MagicMock() + query.where.return_value = query + query.all.return_value = [SimpleNamespace(id="seg-1")] + session.query.return_value = query + + create_session_mock = MagicMock(return_value=_SessionContext(session)) + monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock)) + + monkeypatch.setattr( + SummaryIndexService, + "get_segments_summaries", + MagicMock(return_value={"seg-1": SimpleNamespace(status="generating")}), + ) + assert SummaryIndexService.get_document_summary_index_status("doc-1", "dataset-1", "tenant-1") == "SUMMARIZING" + + # Multiple docs + query2 = MagicMock() + query2.where.return_value = query2 + query2.all.return_value = [seg_row] + session2 = MagicMock() + session2.query.return_value = query2 + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session2))), + ) + monkeypatch.setattr( + SummaryIndexService, + "get_segments_summaries", + MagicMock(return_value={"seg-1": SimpleNamespace(status="not_started")}), + ) + result = SummaryIndexService.get_documents_summary_index_status(["doc-1", "doc-2"], "dataset-1", "tenant-1") + assert result["doc-1"] == "SUMMARIZING" + assert result["doc-2"] is None + + +def test_get_document_summary_status_detail_counts_and_previews(monkeypatch: pytest.MonkeyPatch) -> None: + segment1 = _segment() + segment1.id = "seg-1" + segment1.position = 1 + segment2 = _segment() + segment2.id = "seg-2" + segment2.position = 2 + + summary1 = _summary_record(summary_content="x" * 150, node_id="n1") + summary1.chunk_id = "seg-1" + summary1.status = "completed" + summary1.error = None + summary1.created_at = datetime(2024, 1, 1, tzinfo=UTC) + summary1.updated_at = datetime(2024, 1, 2, tzinfo=UTC) + + segment_service = SimpleNamespace(get_segments_by_document_and_dataset=MagicMock(return_value=[segment1, segment2])) + monkeypatch.setitem(sys.modules, "services.dataset_service", SimpleNamespace(SegmentService=segment_service)) + + monkeypatch.setattr(SummaryIndexService, "get_document_summaries", MagicMock(return_value=[summary1])) + + detail = SummaryIndexService.get_document_summary_status_detail("doc-1", "dataset-1") + assert detail["total_segments"] == 2 + assert detail["summary_status"]["completed"] == 1 + assert detail["summary_status"]["not_started"] == 1 + assert detail["summaries"][0]["summary_preview"].endswith("...") + assert detail["summaries"][1]["status"] == "not_started" diff --git 
a/api/tests/unit_tests/services/test_vector_service.py b/api/tests/unit_tests/services/test_vector_service.py new file mode 100644 index 0000000000..7b0103a2a1 --- /dev/null +++ b/api/tests/unit_tests/services/test_vector_service.py @@ -0,0 +1,704 @@ +"""Unit tests for `api/services/vector_service.py`.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any +from unittest.mock import MagicMock + +import pytest + +import services.vector_service as vector_service_module +from services.vector_service import VectorService + + +@dataclass(frozen=True) +class _UploadFileStub: + id: str + name: str + + +@dataclass(frozen=True) +class _ChildDocStub: + page_content: str + metadata: dict[str, Any] + + +@dataclass +class _ParentDocStub: + children: list[_ChildDocStub] + + +def _make_dataset( + *, + indexing_technique: str = "high_quality", + doc_form: str = "text_model", + tenant_id: str = "tenant-1", + dataset_id: str = "dataset-1", + is_multimodal: bool = False, + embedding_model_provider: str | None = "openai", + embedding_model: str = "text-embedding", +) -> MagicMock: + dataset = MagicMock(name="dataset") + dataset.id = dataset_id + dataset.tenant_id = tenant_id + dataset.doc_form = doc_form + dataset.indexing_technique = indexing_technique + dataset.is_multimodal = is_multimodal + dataset.embedding_model_provider = embedding_model_provider + dataset.embedding_model = embedding_model + return dataset + + +def _make_segment( + *, + segment_id: str = "seg-1", + tenant_id: str = "tenant-1", + dataset_id: str = "dataset-1", + document_id: str = "doc-1", + content: str = "hello", + index_node_id: str = "node-1", + index_node_hash: str = "hash-1", + attachments: list[dict[str, str]] | None = None, +) -> MagicMock: + segment = MagicMock(name="segment") + segment.id = segment_id + segment.tenant_id = tenant_id + segment.dataset_id = dataset_id + segment.document_id = document_id + segment.content = content + segment.index_node_id = index_node_id + segment.index_node_hash = index_node_hash + segment.attachments = attachments or [] + return segment + + +def _mock_db_session_for_update_multimodel(*, upload_files: list[_UploadFileStub] | None) -> MagicMock: + session = MagicMock(name="session") + + binding_query = MagicMock(name="binding_query") + binding_query.where.return_value = binding_query + binding_query.delete.return_value = 1 + + upload_query = MagicMock(name="upload_query") + upload_query.where.return_value = upload_query + upload_query.all.return_value = upload_files or [] + + def query_side_effect(model: object) -> MagicMock: + if model is vector_service_module.SegmentAttachmentBinding: + return binding_query + if model is vector_service_module.UploadFile: + return upload_query + return MagicMock(name=f"query({model})") + + session.query.side_effect = query_side_effect + db_mock = MagicMock(name="db") + db_mock.session = session + return db_mock + + +def test_create_segments_vector_regular_indexing_loads_documents_and_keywords(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(is_multimodal=False) + segment = _make_segment() + + index_processor = MagicMock(name="index_processor") + factory_instance = MagicMock(name="IndexProcessorFactory-instance") + factory_instance.init_index_processor.return_value = index_processor + monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance)) + + VectorService.create_segments_vector([["k1"]], [segment], dataset, "text_model") + + 
index_processor.load.assert_called_once() + args, kwargs = index_processor.load.call_args + assert args[0] == dataset + assert len(args[1]) == 1 + assert args[2] is None + assert kwargs["with_keywords"] is True + assert kwargs["keywords_list"] == [["k1"]] + + +def test_create_segments_vector_regular_indexing_loads_multimodal_documents(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(is_multimodal=True) + segment = _make_segment( + attachments=[ + {"id": "img-1", "name": "a.png"}, + {"id": "img-2", "name": "b.png"}, + ] + ) + + index_processor = MagicMock(name="index_processor") + factory_instance = MagicMock(name="IndexProcessorFactory-instance") + factory_instance.init_index_processor.return_value = index_processor + monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance)) + + VectorService.create_segments_vector([["k1"]], [segment], dataset, "text_model") + + assert index_processor.load.call_count == 2 + first_args, first_kwargs = index_processor.load.call_args_list[0] + assert first_args[0] == dataset + assert len(first_args[1]) == 1 + assert first_kwargs["with_keywords"] is True + + second_args, second_kwargs = index_processor.load.call_args_list[1] + assert second_args[0] == dataset + assert second_args[1] == [] + assert len(second_args[2]) == 2 + assert second_kwargs["with_keywords"] is False + + +def test_create_segments_vector_with_no_segments_does_not_load(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset() + index_processor = MagicMock(name="index_processor") + factory_instance = MagicMock() + factory_instance.init_index_processor.return_value = index_processor + monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance)) + + VectorService.create_segments_vector(None, [], dataset, "text_model") + index_processor.load.assert_not_called() + + +def _mock_parent_child_queries( + *, + dataset_document: object | None, + processing_rule: object | None, +) -> MagicMock: + session = MagicMock(name="session") + + doc_query = MagicMock(name="doc_query") + doc_query.filter_by.return_value = doc_query + doc_query.first.return_value = dataset_document + + rule_query = MagicMock(name="rule_query") + rule_query.where.return_value = rule_query + rule_query.first.return_value = processing_rule + + def query_side_effect(model: object) -> MagicMock: + if model is vector_service_module.DatasetDocument: + return doc_query + if model is vector_service_module.DatasetProcessRule: + return rule_query + return MagicMock(name=f"query({model})") + + session.query.side_effect = query_side_effect + db_mock = MagicMock(name="db") + db_mock.session = session + return db_mock + + +def test_create_segments_vector_parent_child_calls_generate_child_chunks_with_explicit_model( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dataset = _make_dataset( + doc_form=vector_service_module.IndexStructureType.PARENT_CHILD_INDEX, + embedding_model_provider="openai", + indexing_technique="high_quality", + ) + segment = _make_segment() + + dataset_document = MagicMock(name="dataset_document") + dataset_document.id = segment.document_id + dataset_document.dataset_process_rule_id = "rule-1" + dataset_document.doc_language = "en" + dataset_document.created_by = "user-1" + + processing_rule = MagicMock(name="processing_rule") + processing_rule.to_dict.return_value = {"rules": {}} + + monkeypatch.setattr( + vector_service_module, + "db", + 
_mock_parent_child_queries(dataset_document=dataset_document, processing_rule=processing_rule), + ) + + embedding_model_instance = MagicMock(name="embedding_model_instance") + model_manager_instance = MagicMock(name="model_manager_instance") + model_manager_instance.get_model_instance.return_value = embedding_model_instance + monkeypatch.setattr(vector_service_module, "ModelManager", MagicMock(return_value=model_manager_instance)) + + generate_child_chunks_mock = MagicMock() + monkeypatch.setattr(VectorService, "generate_child_chunks", generate_child_chunks_mock) + + index_processor = MagicMock() + factory_instance = MagicMock() + factory_instance.init_index_processor.return_value = index_processor + monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance)) + + VectorService.create_segments_vector( + None, [segment], dataset, vector_service_module.IndexStructureType.PARENT_CHILD_INDEX + ) + + model_manager_instance.get_model_instance.assert_called_once() + generate_child_chunks_mock.assert_called_once_with( + segment, dataset_document, dataset, embedding_model_instance, processing_rule, False + ) + index_processor.load.assert_not_called() + + +def test_create_segments_vector_parent_child_uses_default_embedding_model_when_provider_missing( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dataset = _make_dataset( + doc_form=vector_service_module.IndexStructureType.PARENT_CHILD_INDEX, + embedding_model_provider=None, + indexing_technique="high_quality", + ) + segment = _make_segment() + + dataset_document = MagicMock() + dataset_document.dataset_process_rule_id = "rule-1" + dataset_document.doc_language = "en" + dataset_document.created_by = "user-1" + + processing_rule = MagicMock() + processing_rule.to_dict.return_value = {"rules": {}} + + monkeypatch.setattr( + vector_service_module, + "db", + _mock_parent_child_queries(dataset_document=dataset_document, processing_rule=processing_rule), + ) + + embedding_model_instance = MagicMock() + model_manager_instance = MagicMock() + model_manager_instance.get_default_model_instance.return_value = embedding_model_instance + monkeypatch.setattr(vector_service_module, "ModelManager", MagicMock(return_value=model_manager_instance)) + + generate_child_chunks_mock = MagicMock() + monkeypatch.setattr(VectorService, "generate_child_chunks", generate_child_chunks_mock) + + index_processor = MagicMock() + factory_instance = MagicMock() + factory_instance.init_index_processor.return_value = index_processor + monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance)) + + VectorService.create_segments_vector( + None, [segment], dataset, vector_service_module.IndexStructureType.PARENT_CHILD_INDEX + ) + + model_manager_instance.get_default_model_instance.assert_called_once() + generate_child_chunks_mock.assert_called_once() + + +def test_create_segments_vector_parent_child_missing_document_logs_warning_and_continues( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dataset = _make_dataset(doc_form=vector_service_module.IndexStructureType.PARENT_CHILD_INDEX) + segment = _make_segment() + + processing_rule = MagicMock() + monkeypatch.setattr( + vector_service_module, + "db", + _mock_parent_child_queries(dataset_document=None, processing_rule=processing_rule), + ) + + logger_mock = MagicMock() + monkeypatch.setattr(vector_service_module, "logger", logger_mock) + + index_processor = MagicMock() + factory_instance = MagicMock() + 
factory_instance.init_index_processor.return_value = index_processor + monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance)) + + VectorService.create_segments_vector( + None, [segment], dataset, vector_service_module.IndexStructureType.PARENT_CHILD_INDEX + ) + logger_mock.warning.assert_called_once() + index_processor.load.assert_not_called() + + +def test_create_segments_vector_parent_child_missing_processing_rule_raises(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(doc_form=vector_service_module.IndexStructureType.PARENT_CHILD_INDEX) + segment = _make_segment() + + dataset_document = MagicMock() + dataset_document.dataset_process_rule_id = "rule-1" + monkeypatch.setattr( + vector_service_module, + "db", + _mock_parent_child_queries(dataset_document=dataset_document, processing_rule=None), + ) + + with pytest.raises(ValueError, match="No processing rule found"): + VectorService.create_segments_vector( + None, [segment], dataset, vector_service_module.IndexStructureType.PARENT_CHILD_INDEX + ) + + +def test_create_segments_vector_parent_child_non_high_quality_raises(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset( + doc_form=vector_service_module.IndexStructureType.PARENT_CHILD_INDEX, + indexing_technique="economy", + ) + segment = _make_segment() + dataset_document = MagicMock() + dataset_document.dataset_process_rule_id = "rule-1" + processing_rule = MagicMock() + monkeypatch.setattr( + vector_service_module, + "db", + _mock_parent_child_queries(dataset_document=dataset_document, processing_rule=processing_rule), + ) + + with pytest.raises(ValueError, match="not high quality"): + VectorService.create_segments_vector( + None, [segment], dataset, vector_service_module.IndexStructureType.PARENT_CHILD_INDEX + ) + + +def test_update_segment_vector_high_quality_uses_vector(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(indexing_technique="high_quality") + segment = _make_segment() + + vector_instance = MagicMock() + monkeypatch.setattr(vector_service_module, "Vector", MagicMock(return_value=vector_instance)) + + VectorService.update_segment_vector(["k"], segment, dataset) + + vector_instance.delete_by_ids.assert_called_once_with([segment.index_node_id]) + vector_instance.add_texts.assert_called_once() + add_args, add_kwargs = vector_instance.add_texts.call_args + assert len(add_args[0]) == 1 + assert add_kwargs["duplicate_check"] is True + + +def test_update_segment_vector_economy_uses_keyword_with_keywords_list(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(indexing_technique="economy") + segment = _make_segment() + + keyword_instance = MagicMock() + monkeypatch.setattr(vector_service_module, "Keyword", MagicMock(return_value=keyword_instance)) + + VectorService.update_segment_vector(["a", "b"], segment, dataset) + + keyword_instance.delete_by_ids.assert_called_once_with([segment.index_node_id]) + keyword_instance.add_texts.assert_called_once() + args, kwargs = keyword_instance.add_texts.call_args + assert len(args[0]) == 1 + assert kwargs["keywords_list"] == [["a", "b"]] + + +def test_update_segment_vector_economy_uses_keyword_without_keywords_list(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(indexing_technique="economy") + segment = _make_segment() + + keyword_instance = MagicMock() + monkeypatch.setattr(vector_service_module, "Keyword", MagicMock(return_value=keyword_instance)) + + VectorService.update_segment_vector(None, 
segment, dataset) + keyword_instance.add_texts.assert_called_once() + _, kwargs = keyword_instance.add_texts.call_args + assert "keywords_list" not in kwargs + + +def test_generate_child_chunks_regenerate_cleans_then_saves_children(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(doc_form="text_model", tenant_id="tenant-1", dataset_id="dataset-1") + segment = _make_segment(segment_id="seg-1") + + dataset_document = MagicMock() + dataset_document.id = segment.document_id + dataset_document.doc_language = "en" + dataset_document.created_by = "user-1" + + processing_rule = MagicMock() + processing_rule.to_dict.return_value = {"rules": {}} + + child1 = _ChildDocStub(page_content="c1", metadata={"doc_id": "c1-id", "doc_hash": "c1-h"}) + child2 = _ChildDocStub(page_content="c2", metadata={"doc_id": "c2-id", "doc_hash": "c2-h"}) + transformed = [_ParentDocStub(children=[child1, child2])] + + index_processor = MagicMock() + index_processor.transform.return_value = transformed + factory_instance = MagicMock() + factory_instance.init_index_processor.return_value = index_processor + monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance)) + + child_chunk_ctor = MagicMock(side_effect=lambda **kwargs: kwargs) + monkeypatch.setattr(vector_service_module, "ChildChunk", child_chunk_ctor) + + db_mock = MagicMock() + db_mock.session.add = MagicMock() + db_mock.session.commit = MagicMock() + monkeypatch.setattr(vector_service_module, "db", db_mock) + + VectorService.generate_child_chunks( + segment=segment, + dataset_document=dataset_document, + dataset=dataset, + embedding_model_instance=MagicMock(), + processing_rule=processing_rule, + regenerate=True, + ) + + index_processor.clean.assert_called_once() + _, transform_kwargs = index_processor.transform.call_args + assert transform_kwargs["process_rule"]["rules"]["parent_mode"] == vector_service_module.ParentMode.FULL_DOC + index_processor.load.assert_called_once() + assert db_mock.session.add.call_count == 2 + db_mock.session.commit.assert_called_once() + + +def test_generate_child_chunks_commits_even_when_no_children(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(doc_form="text_model") + segment = _make_segment() + dataset_document = MagicMock() + dataset_document.doc_language = "en" + dataset_document.created_by = "user-1" + + processing_rule = MagicMock() + processing_rule.to_dict.return_value = {"rules": {}} + + index_processor = MagicMock() + index_processor.transform.return_value = [_ParentDocStub(children=[])] + factory_instance = MagicMock() + factory_instance.init_index_processor.return_value = index_processor + monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance)) + + db_mock = MagicMock() + monkeypatch.setattr(vector_service_module, "db", db_mock) + + VectorService.generate_child_chunks( + segment=segment, + dataset_document=dataset_document, + dataset=dataset, + embedding_model_instance=MagicMock(), + processing_rule=processing_rule, + regenerate=False, + ) + + index_processor.load.assert_not_called() + db_mock.session.add.assert_not_called() + db_mock.session.commit.assert_called_once() + + +def test_create_child_chunk_vector_high_quality_adds_texts(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(indexing_technique="high_quality") + child_chunk = MagicMock() + child_chunk.content = "child" + child_chunk.index_node_id = "id" + child_chunk.index_node_hash = "h" + 
child_chunk.document_id = "doc-1" + child_chunk.dataset_id = "dataset-1" + + vector_instance = MagicMock() + monkeypatch.setattr(vector_service_module, "Vector", MagicMock(return_value=vector_instance)) + + VectorService.create_child_chunk_vector(child_chunk, dataset) + vector_instance.add_texts.assert_called_once() + + +def test_create_child_chunk_vector_economy_noop(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(indexing_technique="economy") + vector_cls = MagicMock() + monkeypatch.setattr(vector_service_module, "Vector", vector_cls) + + child_chunk = MagicMock() + child_chunk.content = "child" + child_chunk.index_node_id = "id" + child_chunk.index_node_hash = "h" + child_chunk.document_id = "doc-1" + child_chunk.dataset_id = "dataset-1" + + VectorService.create_child_chunk_vector(child_chunk, dataset) + vector_cls.assert_not_called() + + +def test_update_child_chunk_vector_high_quality_updates_vector(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(indexing_technique="high_quality") + + new_chunk = MagicMock() + new_chunk.content = "n" + new_chunk.index_node_id = "nid" + new_chunk.index_node_hash = "nh" + new_chunk.document_id = "d" + new_chunk.dataset_id = "ds" + + upd_chunk = MagicMock() + upd_chunk.content = "u" + upd_chunk.index_node_id = "uid" + upd_chunk.index_node_hash = "uh" + upd_chunk.document_id = "d" + upd_chunk.dataset_id = "ds" + + del_chunk = MagicMock() + del_chunk.index_node_id = "did" + + vector_instance = MagicMock() + monkeypatch.setattr(vector_service_module, "Vector", MagicMock(return_value=vector_instance)) + + VectorService.update_child_chunk_vector([new_chunk], [upd_chunk], [del_chunk], dataset) + + vector_instance.delete_by_ids.assert_called_once_with(["uid", "did"]) + vector_instance.add_texts.assert_called_once() + docs = vector_instance.add_texts.call_args.args[0] + assert len(docs) == 2 + + +def test_update_child_chunk_vector_economy_noop(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(indexing_technique="economy") + vector_cls = MagicMock() + monkeypatch.setattr(vector_service_module, "Vector", vector_cls) + VectorService.update_child_chunk_vector([], [], [], dataset) + vector_cls.assert_not_called() + + +def test_delete_child_chunk_vector_deletes_by_id(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset() + child_chunk = MagicMock() + child_chunk.index_node_id = "cid" + + vector_instance = MagicMock() + monkeypatch.setattr(vector_service_module, "Vector", MagicMock(return_value=vector_instance)) + + VectorService.delete_child_chunk_vector(child_chunk, dataset) + vector_instance.delete_by_ids.assert_called_once_with(["cid"]) + + +# --------------------------------------------------------------------------- +# update_multimodel_vector (missing coverage in previous suites) +# --------------------------------------------------------------------------- + + +def test_update_multimodel_vector_returns_when_not_high_quality(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(indexing_technique="economy", is_multimodal=True) + segment = _make_segment(tenant_id="t", attachments=[{"id": "a"}]) + + vector_cls = MagicMock() + db_mock = _mock_db_session_for_update_multimodel(upload_files=[]) + monkeypatch.setattr(vector_service_module, "Vector", vector_cls) + monkeypatch.setattr(vector_service_module, "db", db_mock) + + VectorService.update_multimodel_vector(segment=segment, attachment_ids=["a"], dataset=dataset) + vector_cls.assert_not_called() + 
db_mock.session.query.assert_not_called() + + +def test_update_multimodel_vector_returns_when_no_actual_change(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True) + segment = _make_segment(tenant_id="t", attachments=[{"id": "a"}, {"id": "b"}]) + + vector_cls = MagicMock() + db_mock = _mock_db_session_for_update_multimodel(upload_files=[]) + monkeypatch.setattr(vector_service_module, "Vector", vector_cls) + monkeypatch.setattr(vector_service_module, "db", db_mock) + + VectorService.update_multimodel_vector(segment=segment, attachment_ids=["b", "a"], dataset=dataset) + vector_cls.assert_not_called() + db_mock.session.query.assert_not_called() + + +def test_update_multimodel_vector_deletes_bindings_and_commits_on_empty_new_ids( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True) + segment = _make_segment(tenant_id="tenant-1", attachments=[{"id": "old-1"}, {"id": "old-2"}]) + + vector_instance = MagicMock(name="vector_instance") + vector_cls = MagicMock(return_value=vector_instance) + db_mock = _mock_db_session_for_update_multimodel(upload_files=[]) + + monkeypatch.setattr(vector_service_module, "Vector", vector_cls) + monkeypatch.setattr(vector_service_module, "db", db_mock) + + VectorService.update_multimodel_vector(segment=segment, attachment_ids=[], dataset=dataset) + + vector_cls.assert_called_once_with(dataset=dataset) + vector_instance.delete_by_ids.assert_called_once_with(["old-1", "old-2"]) + db_mock.session.query.assert_called_once_with(vector_service_module.SegmentAttachmentBinding) + db_mock.session.commit.assert_called_once() + db_mock.session.add_all.assert_not_called() + vector_instance.add_texts.assert_not_called() + + +def test_update_multimodel_vector_commits_when_no_upload_files_found(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True) + segment = _make_segment(tenant_id="tenant-1", attachments=[{"id": "old-1"}]) + + vector_instance = MagicMock() + monkeypatch.setattr(vector_service_module, "Vector", MagicMock(return_value=vector_instance)) + db_mock = _mock_db_session_for_update_multimodel(upload_files=[]) + monkeypatch.setattr(vector_service_module, "db", db_mock) + + VectorService.update_multimodel_vector(segment=segment, attachment_ids=["new-1"], dataset=dataset) + + db_mock.session.commit.assert_called_once() + db_mock.session.add_all.assert_not_called() + vector_instance.add_texts.assert_not_called() + + +def test_update_multimodel_vector_adds_bindings_and_vectors_and_skips_missing_upload_files( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True) + segment = _make_segment(segment_id="seg-1", tenant_id="tenant-1", attachments=[{"id": "old-1"}]) + + vector_instance = MagicMock() + monkeypatch.setattr(vector_service_module, "Vector", MagicMock(return_value=vector_instance)) + db_mock = _mock_db_session_for_update_multimodel(upload_files=[_UploadFileStub(id="file-1", name="img.png")]) + monkeypatch.setattr(vector_service_module, "db", db_mock) + + binding_ctor = MagicMock(side_effect=lambda **kwargs: kwargs) + monkeypatch.setattr(vector_service_module, "SegmentAttachmentBinding", binding_ctor) + + logger_mock = MagicMock() + monkeypatch.setattr(vector_service_module, "logger", logger_mock) + + VectorService.update_multimodel_vector(segment=segment, attachment_ids=["file-1", 
"missing"], dataset=dataset) + + logger_mock.warning.assert_called_once() + db_mock.session.add_all.assert_called_once() + bindings = db_mock.session.add_all.call_args.args[0] + assert len(bindings) == 1 + assert bindings[0]["attachment_id"] == "file-1" + + vector_instance.add_texts.assert_called_once() + documents = vector_instance.add_texts.call_args.args[0] + assert len(documents) == 1 + assert documents[0].page_content == "img.png" + assert documents[0].metadata["doc_id"] == "file-1" + db_mock.session.commit.assert_called_once() + + +def test_update_multimodel_vector_updates_bindings_without_multimodal_vector_ops( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=False) + segment = _make_segment(tenant_id="tenant-1", attachments=[{"id": "old-1"}]) + + vector_instance = MagicMock() + monkeypatch.setattr(vector_service_module, "Vector", MagicMock(return_value=vector_instance)) + db_mock = _mock_db_session_for_update_multimodel(upload_files=[_UploadFileStub(id="file-1", name="img.png")]) + monkeypatch.setattr(vector_service_module, "db", db_mock) + monkeypatch.setattr( + vector_service_module, "SegmentAttachmentBinding", MagicMock(side_effect=lambda **kwargs: kwargs) + ) + + VectorService.update_multimodel_vector(segment=segment, attachment_ids=["file-1"], dataset=dataset) + + vector_instance.delete_by_ids.assert_not_called() + vector_instance.add_texts.assert_not_called() + db_mock.session.add_all.assert_called_once() + db_mock.session.commit.assert_called_once() + + +def test_update_multimodel_vector_rolls_back_and_reraises_on_error(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True) + segment = _make_segment(segment_id="seg-1", tenant_id="tenant-1", attachments=[{"id": "old-1"}]) + + vector_instance = MagicMock() + monkeypatch.setattr(vector_service_module, "Vector", MagicMock(return_value=vector_instance)) + db_mock = _mock_db_session_for_update_multimodel(upload_files=[_UploadFileStub(id="file-1", name="img.png")]) + db_mock.session.commit.side_effect = RuntimeError("boom") + monkeypatch.setattr(vector_service_module, "db", db_mock) + monkeypatch.setattr( + vector_service_module, "SegmentAttachmentBinding", MagicMock(side_effect=lambda **kwargs: kwargs) + ) + + logger_mock = MagicMock() + monkeypatch.setattr(vector_service_module, "logger", logger_mock) + + with pytest.raises(RuntimeError, match="boom"): + VectorService.update_multimodel_vector(segment=segment, attachment_ids=["file-1"], dataset=dataset) + + logger_mock.exception.assert_called_once() + db_mock.session.rollback.assert_called_once() diff --git a/api/tests/unit_tests/services/test_website_service.py b/api/tests/unit_tests/services/test_website_service.py new file mode 100644 index 0000000000..38d94f4736 --- /dev/null +++ b/api/tests/unit_tests/services/test_website_service.py @@ -0,0 +1,718 @@ +"""Unit tests for services.website_service. + +Focuses on provider dispatching, argument validation, and provider-specific branches +without making any real network/storage/redis calls. 
+""" + +from __future__ import annotations + +import json +from dataclasses import dataclass +from datetime import UTC, datetime +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest + +import services.website_service as website_service_module +from services.website_service import ( + CrawlOptions, + WebsiteCrawlApiRequest, + WebsiteCrawlStatusApiRequest, + WebsiteService, +) + + +@dataclass(frozen=True) +class _DummyHttpxResponse: + payload: dict[str, Any] + + def json(self) -> dict[str, Any]: + return self.payload + + +@pytest.fixture(autouse=True) +def stub_current_user(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + website_service_module, + "current_user", + type("User", (), {"current_tenant_id": "tenant-1"})(), + ) + + +def test_crawl_options_include_exclude_paths() -> None: + options = CrawlOptions(includes="a,b", excludes="x,y") + assert options.get_include_paths() == ["a", "b"] + assert options.get_exclude_paths() == ["x", "y"] + + empty = CrawlOptions(includes=None, excludes=None) + assert empty.get_include_paths() == [] + assert empty.get_exclude_paths() == [] + + +def test_website_crawl_api_request_from_args_valid_and_to_crawl_request() -> None: + args = { + "provider": "firecrawl", + "url": "https://example.com", + "options": { + "limit": 2, + "crawl_sub_pages": True, + "only_main_content": True, + "includes": "a,b", + "excludes": "x", + "prompt": "hi", + "max_depth": 3, + "use_sitemap": False, + }, + } + + api_req = WebsiteCrawlApiRequest.from_args(args) + crawl_req = api_req.to_crawl_request() + + assert crawl_req.provider == "firecrawl" + assert crawl_req.url == "https://example.com" + assert crawl_req.options.limit == 2 + assert crawl_req.options.crawl_sub_pages is True + assert crawl_req.options.only_main_content is True + assert crawl_req.options.get_include_paths() == ["a", "b"] + assert crawl_req.options.get_exclude_paths() == ["x"] + assert crawl_req.options.prompt == "hi" + assert crawl_req.options.max_depth == 3 + assert crawl_req.options.use_sitemap is False + + +@pytest.mark.parametrize( + ("args", "missing_msg"), + [ + ({}, "Provider is required"), + ({"provider": "firecrawl"}, "URL is required"), + ({"provider": "firecrawl", "url": "https://example.com"}, "Options are required"), + ], +) +def test_website_crawl_api_request_from_args_requires_fields(args: dict, missing_msg: str) -> None: + with pytest.raises(ValueError, match=missing_msg): + WebsiteCrawlApiRequest.from_args(args) + + +def test_website_crawl_status_api_request_from_args_requires_fields() -> None: + with pytest.raises(ValueError, match="Provider is required"): + WebsiteCrawlStatusApiRequest.from_args({}, job_id="job-1") + + with pytest.raises(ValueError, match="Job ID is required"): + WebsiteCrawlStatusApiRequest.from_args({"provider": "firecrawl"}, job_id="") + + req = WebsiteCrawlStatusApiRequest.from_args({"provider": "firecrawl"}, job_id="job-1") + assert req.provider == "firecrawl" + assert req.job_id == "job-1" + + +def test_get_credentials_and_config_selects_plugin_id_and_key_firecrawl(monkeypatch: pytest.MonkeyPatch) -> None: + service_instance = MagicMock(name="DatasourceProviderService-instance") + service_instance.get_datasource_credentials.return_value = {"firecrawl_api_key": "k", "base_url": "b"} + monkeypatch.setattr(website_service_module, "DatasourceProviderService", MagicMock(return_value=service_instance)) + + api_key, config = WebsiteService._get_credentials_and_config("tenant-1", "firecrawl") + assert api_key == "k" + assert 
config["base_url"] == "b" + + service_instance.get_datasource_credentials.assert_called_once_with( + tenant_id="tenant-1", + provider="firecrawl", + plugin_id="langgenius/firecrawl_datasource", + ) + + +@pytest.mark.parametrize( + ("provider", "plugin_id"), + [ + ("watercrawl", "langgenius/watercrawl_datasource"), + ("jinareader", "langgenius/jina_datasource"), + ], +) +def test_get_credentials_and_config_selects_plugin_id_and_key_api_key( + monkeypatch: pytest.MonkeyPatch, provider: str, plugin_id: str +) -> None: + service_instance = MagicMock(name="DatasourceProviderService-instance") + service_instance.get_datasource_credentials.return_value = {"api_key": "enc-key", "base_url": "b"} + monkeypatch.setattr(website_service_module, "DatasourceProviderService", MagicMock(return_value=service_instance)) + + api_key, config = WebsiteService._get_credentials_and_config("tenant-1", provider) + assert api_key == "enc-key" + assert config["base_url"] == "b" + + service_instance.get_datasource_credentials.assert_called_once_with( + tenant_id="tenant-1", + provider=provider, + plugin_id=plugin_id, + ) + + +def test_get_credentials_and_config_rejects_invalid_provider() -> None: + with pytest.raises(ValueError, match="Invalid provider"): + WebsiteService._get_credentials_and_config("tenant-1", "unknown") + + +def test_get_credentials_and_config_hits_unreachable_guard_branch(monkeypatch: pytest.MonkeyPatch) -> None: + class FlakyProvider: + def __init__(self) -> None: + self._eq_calls = 0 + + def __hash__(self) -> int: + return 1 + + def __eq__(self, other: object) -> bool: + if other == "firecrawl": + self._eq_calls += 1 + return self._eq_calls == 1 + return False + + def __repr__(self) -> str: + return "FlakyProvider()" + + service_instance = MagicMock(name="DatasourceProviderService-instance") + service_instance.get_datasource_credentials.return_value = {"firecrawl_api_key": "k"} + monkeypatch.setattr(website_service_module, "DatasourceProviderService", MagicMock(return_value=service_instance)) + + with pytest.raises(ValueError, match="Invalid provider"): + WebsiteService._get_credentials_and_config("tenant-1", FlakyProvider()) # type: ignore[arg-type] + + +def test_get_decrypted_api_key_requires_api_key(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(website_service_module.encrypter, "decrypt_token", MagicMock()) + with pytest.raises(ValueError, match="API key not found in configuration"): + WebsiteService._get_decrypted_api_key("tenant-1", {}) + + +def test_get_decrypted_api_key_decrypts(monkeypatch: pytest.MonkeyPatch) -> None: + decrypt_mock = MagicMock(return_value="plain") + monkeypatch.setattr(website_service_module.encrypter, "decrypt_token", decrypt_mock) + + assert WebsiteService._get_decrypted_api_key("tenant-1", {"api_key": "enc"}) == "plain" + decrypt_mock.assert_called_once_with(tenant_id="tenant-1", token="enc") + + +def test_document_create_args_validate_wraps_error_message() -> None: + with pytest.raises(ValueError, match=r"^Invalid arguments: Provider is required$"): + WebsiteService.document_create_args_validate({}) + + +def test_crawl_url_dispatches_by_provider(monkeypatch: pytest.MonkeyPatch) -> None: + api_request = WebsiteCrawlApiRequest(provider="firecrawl", url="https://example.com", options={"limit": 1}) + crawl_request = api_request.to_crawl_request() + + monkeypatch.setattr(WebsiteService, "_get_credentials_and_config", MagicMock(return_value=("k", {"base_url": "b"}))) + firecrawl_mock = MagicMock(return_value={"status": "active", "job_id": "j1"}) + 
monkeypatch.setattr(WebsiteService, "_crawl_with_firecrawl", firecrawl_mock) + + result = WebsiteService.crawl_url(api_request) + + assert result == {"status": "active", "job_id": "j1"} + firecrawl_mock.assert_called_once() + assert firecrawl_mock.call_args.kwargs["request"] == crawl_request + + +@pytest.mark.parametrize( + ("provider", "method_name"), + [ + ("watercrawl", "_crawl_with_watercrawl"), + ("jinareader", "_crawl_with_jinareader"), + ], +) +def test_crawl_url_dispatches_other_providers(monkeypatch: pytest.MonkeyPatch, provider: str, method_name: str) -> None: + api_request = WebsiteCrawlApiRequest(provider=provider, url="https://example.com", options={"limit": 1}) + monkeypatch.setattr(WebsiteService, "_get_credentials_and_config", MagicMock(return_value=("k", {"base_url": "b"}))) + + impl_mock = MagicMock(return_value={"status": "active"}) + monkeypatch.setattr(WebsiteService, method_name, impl_mock) + + assert WebsiteService.crawl_url(api_request) == {"status": "active"} + impl_mock.assert_called_once() + + +def test_crawl_url_rejects_invalid_provider(monkeypatch: pytest.MonkeyPatch) -> None: + api_request = WebsiteCrawlApiRequest(provider="bad", url="https://example.com", options={"limit": 1}) + monkeypatch.setattr(WebsiteService, "_get_credentials_and_config", MagicMock(return_value=("k", {}))) + + with pytest.raises(ValueError, match="Invalid provider"): + WebsiteService.crawl_url(api_request) + + +def test_crawl_with_firecrawl_builds_params_single_page_and_sets_redis(monkeypatch: pytest.MonkeyPatch) -> None: + firecrawl_instance = MagicMock(name="FirecrawlApp-instance") + firecrawl_instance.crawl_url.return_value = "job-1" + firecrawl_cls = MagicMock(return_value=firecrawl_instance) + monkeypatch.setattr(website_service_module, "FirecrawlApp", firecrawl_cls) + + redis_mock = MagicMock() + monkeypatch.setattr(website_service_module, "redis_client", redis_mock) + + fixed_now = datetime(2024, 1, 1, tzinfo=UTC) + with patch.object(website_service_module.datetime, "datetime") as datetime_mock: + datetime_mock.now.return_value = fixed_now + + req = WebsiteCrawlApiRequest( + provider="firecrawl", url="https://example.com", options={"limit": 5} + ).to_crawl_request() + req.options.crawl_sub_pages = False + req.options.only_main_content = True + + result = WebsiteService._crawl_with_firecrawl(request=req, api_key="k", config={"base_url": "b"}) + + assert result == {"status": "active", "job_id": "job-1"} + + firecrawl_cls.assert_called_once_with(api_key="k", base_url="b") + firecrawl_instance.crawl_url.assert_called_once() + _, params = firecrawl_instance.crawl_url.call_args.args + assert params["limit"] == 1 + assert params["includePaths"] == [] + assert params["excludePaths"] == [] + assert params["scrapeOptions"] == {"onlyMainContent": True} + + redis_mock.setex.assert_called_once() + key, ttl, value = redis_mock.setex.call_args.args + assert key == "website_crawl_job-1" + assert ttl == 3600 + assert float(value) == pytest.approx(fixed_now.timestamp(), rel=0, abs=1e-6) + + +def test_crawl_with_firecrawl_builds_params_multi_page_including_prompt(monkeypatch: pytest.MonkeyPatch) -> None: + firecrawl_instance = MagicMock(name="FirecrawlApp-instance") + firecrawl_instance.crawl_url.return_value = "job-2" + monkeypatch.setattr(website_service_module, "FirecrawlApp", MagicMock(return_value=firecrawl_instance)) + monkeypatch.setattr(website_service_module, "redis_client", MagicMock()) + + req = WebsiteCrawlApiRequest( + provider="firecrawl", + url="https://example.com", + options={ + 
"crawl_sub_pages": True, + "limit": 3, + "only_main_content": False, + "includes": "a,b", + "excludes": "x", + "prompt": "use this", + }, + ).to_crawl_request() + + WebsiteService._crawl_with_firecrawl(request=req, api_key="k", config={"base_url": None}) + _, params = firecrawl_instance.crawl_url.call_args.args + assert params["includePaths"] == ["a", "b"] + assert params["excludePaths"] == ["x"] + assert params["limit"] == 3 + assert params["scrapeOptions"] == {"onlyMainContent": False} + assert params["prompt"] == "use this" + + +def test_crawl_with_watercrawl_passes_options_dict(monkeypatch: pytest.MonkeyPatch) -> None: + provider_instance = MagicMock() + provider_instance.crawl_url.return_value = {"status": "active", "job_id": "w1"} + provider_cls = MagicMock(return_value=provider_instance) + monkeypatch.setattr(website_service_module, "WaterCrawlProvider", provider_cls) + + req = WebsiteCrawlApiRequest( + provider="watercrawl", + url="https://example.com", + options={ + "limit": 2, + "crawl_sub_pages": True, + "only_main_content": True, + "includes": "a", + "excludes": None, + "max_depth": 5, + "use_sitemap": False, + }, + ).to_crawl_request() + + result = WebsiteService._crawl_with_watercrawl(request=req, api_key="k", config={"base_url": "b"}) + assert result == {"status": "active", "job_id": "w1"} + + provider_cls.assert_called_once_with(api_key="k", base_url="b") + provider_instance.crawl_url.assert_called_once_with( + url="https://example.com", + options={ + "limit": 2, + "crawl_sub_pages": True, + "only_main_content": True, + "includes": "a", + "excludes": None, + "max_depth": 5, + "use_sitemap": False, + }, + ) + + +def test_crawl_with_jinareader_single_page_success(monkeypatch: pytest.MonkeyPatch) -> None: + get_mock = MagicMock(return_value=_DummyHttpxResponse({"code": 200, "data": {"title": "t"}})) + monkeypatch.setattr(website_service_module.httpx, "get", get_mock) + + req = WebsiteCrawlApiRequest( + provider="jinareader", url="https://example.com", options={"crawl_sub_pages": False} + ).to_crawl_request() + req.options.crawl_sub_pages = False + + result = WebsiteService._crawl_with_jinareader(request=req, api_key="k") + assert result == {"status": "active", "data": {"title": "t"}} + get_mock.assert_called_once() + + +def test_crawl_with_jinareader_single_page_failure(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(website_service_module.httpx, "get", MagicMock(return_value=_DummyHttpxResponse({"code": 500}))) + req = WebsiteCrawlApiRequest( + provider="jinareader", url="https://example.com", options={"crawl_sub_pages": False} + ).to_crawl_request() + req.options.crawl_sub_pages = False + + with pytest.raises(ValueError, match="Failed to crawl:"): + WebsiteService._crawl_with_jinareader(request=req, api_key="k") + + +def test_crawl_with_jinareader_multi_page_success(monkeypatch: pytest.MonkeyPatch) -> None: + post_mock = MagicMock(return_value=_DummyHttpxResponse({"code": 200, "data": {"taskId": "t1"}})) + monkeypatch.setattr(website_service_module.httpx, "post", post_mock) + + req = WebsiteCrawlApiRequest( + provider="jinareader", + url="https://example.com", + options={"crawl_sub_pages": True, "limit": 5, "use_sitemap": True}, + ).to_crawl_request() + req.options.crawl_sub_pages = True + + result = WebsiteService._crawl_with_jinareader(request=req, api_key="k") + assert result == {"status": "active", "job_id": "t1"} + post_mock.assert_called_once() + + +def test_crawl_with_jinareader_multi_page_failure(monkeypatch: pytest.MonkeyPatch) -> None: + 
monkeypatch.setattr( + website_service_module.httpx, "post", MagicMock(return_value=_DummyHttpxResponse({"code": 400})) + ) + req = WebsiteCrawlApiRequest( + provider="jinareader", + url="https://example.com", + options={"crawl_sub_pages": True, "limit": 2, "use_sitemap": False}, + ).to_crawl_request() + req.options.crawl_sub_pages = True + + with pytest.raises(ValueError, match="Failed to crawl$"): + WebsiteService._crawl_with_jinareader(request=req, api_key="k") + + +def test_get_crawl_status_dispatches(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(WebsiteService, "_get_credentials_and_config", MagicMock(return_value=("k", {"base_url": "b"}))) + firecrawl_status = MagicMock(return_value={"status": "active"}) + monkeypatch.setattr(WebsiteService, "_get_firecrawl_status", firecrawl_status) + + result = WebsiteService.get_crawl_status("job-1", "firecrawl") + assert result == {"status": "active"} + firecrawl_status.assert_called_once_with("job-1", "k", {"base_url": "b"}) + + watercrawl_status = MagicMock(return_value={"status": "active", "job_id": "w"}) + monkeypatch.setattr(WebsiteService, "_get_watercrawl_status", watercrawl_status) + assert WebsiteService.get_crawl_status("job-2", "watercrawl") == {"status": "active", "job_id": "w"} + watercrawl_status.assert_called_once_with("job-2", "k", {"base_url": "b"}) + + jinareader_status = MagicMock(return_value={"status": "active", "job_id": "j"}) + monkeypatch.setattr(WebsiteService, "_get_jinareader_status", jinareader_status) + assert WebsiteService.get_crawl_status("job-3", "jinareader") == {"status": "active", "job_id": "j"} + jinareader_status.assert_called_once_with("job-3", "k") + + +def test_get_crawl_status_typed_rejects_invalid_provider(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(WebsiteService, "_get_credentials_and_config", MagicMock(return_value=("k", {}))) + with pytest.raises(ValueError, match="Invalid provider"): + WebsiteService.get_crawl_status_typed(WebsiteCrawlStatusApiRequest(provider="bad", job_id="j")) + + +def test_get_firecrawl_status_adds_time_consuming_when_completed_and_cached(monkeypatch: pytest.MonkeyPatch) -> None: + firecrawl_instance = MagicMock() + firecrawl_instance.check_crawl_status.return_value = {"status": "completed", "total": 2, "current": 2, "data": []} + monkeypatch.setattr(website_service_module, "FirecrawlApp", MagicMock(return_value=firecrawl_instance)) + + redis_mock = MagicMock() + redis_mock.get.return_value = b"100.0" + monkeypatch.setattr(website_service_module, "redis_client", redis_mock) + + with patch.object(website_service_module.datetime, "datetime") as datetime_mock: + datetime_mock.now.return_value = datetime.fromtimestamp(105.0, tz=UTC) + result = WebsiteService._get_firecrawl_status(job_id="job-1", api_key="k", config={"base_url": "b"}) + + assert result["status"] == "completed" + assert result["time_consuming"] == "5.00" + redis_mock.delete.assert_called_once_with("website_crawl_job-1") + + +def test_get_firecrawl_status_completed_without_cache_does_not_add_time(monkeypatch: pytest.MonkeyPatch) -> None: + firecrawl_instance = MagicMock() + firecrawl_instance.check_crawl_status.return_value = {"status": "completed"} + monkeypatch.setattr(website_service_module, "FirecrawlApp", MagicMock(return_value=firecrawl_instance)) + + redis_mock = MagicMock() + redis_mock.get.return_value = None + monkeypatch.setattr(website_service_module, "redis_client", redis_mock) + + result = WebsiteService._get_firecrawl_status(job_id="job-1", api_key="k", 
config={"base_url": None}) + assert result["status"] == "completed" + assert "time_consuming" not in result + redis_mock.delete.assert_not_called() + + +def test_get_watercrawl_status_delegates(monkeypatch: pytest.MonkeyPatch) -> None: + provider_instance = MagicMock() + provider_instance.get_crawl_status.return_value = {"status": "active", "job_id": "w1"} + monkeypatch.setattr(website_service_module, "WaterCrawlProvider", MagicMock(return_value=provider_instance)) + + assert WebsiteService._get_watercrawl_status("job-1", "k", {"base_url": "b"}) == { + "status": "active", + "job_id": "w1", + } + provider_instance.get_crawl_status.assert_called_once_with("job-1") + + +def test_get_jinareader_status_active(monkeypatch: pytest.MonkeyPatch) -> None: + post_mock = MagicMock( + return_value=_DummyHttpxResponse( + { + "data": { + "status": "active", + "urls": ["a", "b"], + "processed": {"a": {}}, + "failed": {"b": {}}, + "duration": 3000, + } + } + ) + ) + monkeypatch.setattr(website_service_module.httpx, "post", post_mock) + + result = WebsiteService._get_jinareader_status("job-1", "k") + assert result["status"] == "active" + assert result["total"] == 2 + assert result["current"] == 2 + assert result["time_consuming"] == 3.0 + assert result["data"] == [] + post_mock.assert_called_once() + + +def test_get_jinareader_status_completed_formats_processed_items(monkeypatch: pytest.MonkeyPatch) -> None: + status_payload = { + "data": { + "status": "completed", + "urls": ["u1"], + "processed": {"u1": {}}, + "failed": {}, + "duration": 1000, + } + } + processed_payload = { + "data": { + "processed": { + "u1": { + "data": { + "title": "t", + "url": "u1", + "description": "d", + "content": "md", + } + } + } + } + } + post_mock = MagicMock(side_effect=[_DummyHttpxResponse(status_payload), _DummyHttpxResponse(processed_payload)]) + monkeypatch.setattr(website_service_module.httpx, "post", post_mock) + + result = WebsiteService._get_jinareader_status("job-1", "k") + assert result["status"] == "completed" + assert result["data"] == [{"title": "t", "source_url": "u1", "description": "d", "markdown": "md"}] + assert post_mock.call_count == 2 + + +def test_get_crawl_url_data_dispatches_invalid_provider() -> None: + with pytest.raises(ValueError, match="Invalid provider"): + WebsiteService.get_crawl_url_data("job-1", "bad", "https://example.com", "tenant-1") + + +def test_get_crawl_url_data_hits_invalid_provider_branch_when_credentials_stubbed( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr(WebsiteService, "_get_credentials_and_config", MagicMock(return_value=("k", {}))) + with pytest.raises(ValueError, match="Invalid provider"): + WebsiteService.get_crawl_url_data("job-1", object(), "u", "tenant-1") # type: ignore[arg-type] + + +@pytest.mark.parametrize( + ("provider", "method_name"), + [ + ("firecrawl", "_get_firecrawl_url_data"), + ("watercrawl", "_get_watercrawl_url_data"), + ("jinareader", "_get_jinareader_url_data"), + ], +) +def test_get_crawl_url_data_dispatches(monkeypatch: pytest.MonkeyPatch, provider: str, method_name: str) -> None: + monkeypatch.setattr(WebsiteService, "_get_credentials_and_config", MagicMock(return_value=("k", {"base_url": "b"}))) + impl_mock = MagicMock(return_value={"ok": True}) + monkeypatch.setattr(WebsiteService, method_name, impl_mock) + + result = WebsiteService.get_crawl_url_data("job-1", provider, "u", "tenant-1") + assert result == {"ok": True} + impl_mock.assert_called_once() + + +def 
test_get_firecrawl_url_data_reads_from_storage_when_present(monkeypatch: pytest.MonkeyPatch) -> None: + stored_list = [{"source_url": "https://example.com", "title": "t"}] + stored = json.dumps(stored_list).encode("utf-8") + + storage_mock = MagicMock() + storage_mock.exists.return_value = True + storage_mock.load_once.return_value = stored + monkeypatch.setattr(website_service_module, "storage", storage_mock) + + monkeypatch.setattr(website_service_module, "FirecrawlApp", MagicMock()) + + result = WebsiteService._get_firecrawl_url_data("job-1", "https://example.com", "k", {"base_url": "b"}) + assert result == {"source_url": "https://example.com", "title": "t"} + assert result is not stored_list[0] + + +def test_get_firecrawl_url_data_returns_none_when_storage_empty(monkeypatch: pytest.MonkeyPatch) -> None: + storage_mock = MagicMock() + storage_mock.exists.return_value = True + storage_mock.load_once.return_value = b"" + monkeypatch.setattr(website_service_module, "storage", storage_mock) + + assert WebsiteService._get_firecrawl_url_data("job-1", "https://example.com", "k", {}) is None + + +def test_get_firecrawl_url_data_raises_when_job_not_completed(monkeypatch: pytest.MonkeyPatch) -> None: + storage_mock = MagicMock() + storage_mock.exists.return_value = False + monkeypatch.setattr(website_service_module, "storage", storage_mock) + + firecrawl_instance = MagicMock() + firecrawl_instance.check_crawl_status.return_value = {"status": "active"} + monkeypatch.setattr(website_service_module, "FirecrawlApp", MagicMock(return_value=firecrawl_instance)) + + with pytest.raises(ValueError, match="Crawl job is not completed"): + WebsiteService._get_firecrawl_url_data("job-1", "https://example.com", "k", {"base_url": None}) + + +def test_get_firecrawl_url_data_returns_none_when_not_found(monkeypatch: pytest.MonkeyPatch) -> None: + storage_mock = MagicMock() + storage_mock.exists.return_value = False + monkeypatch.setattr(website_service_module, "storage", storage_mock) + + firecrawl_instance = MagicMock() + firecrawl_instance.check_crawl_status.return_value = {"status": "completed", "data": [{"source_url": "x"}]} + monkeypatch.setattr(website_service_module, "FirecrawlApp", MagicMock(return_value=firecrawl_instance)) + + assert WebsiteService._get_firecrawl_url_data("job-1", "https://example.com", "k", {"base_url": "b"}) is None + + +def test_get_watercrawl_url_data_delegates(monkeypatch: pytest.MonkeyPatch) -> None: + provider_instance = MagicMock() + provider_instance.get_crawl_url_data.return_value = {"source_url": "u"} + monkeypatch.setattr(website_service_module, "WaterCrawlProvider", MagicMock(return_value=provider_instance)) + + result = WebsiteService._get_watercrawl_url_data("job-1", "u", "k", {"base_url": "b"}) + assert result == {"source_url": "u"} + provider_instance.get_crawl_url_data.assert_called_once_with("job-1", "u") + + +def test_get_jinareader_url_data_without_job_id_success(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + website_service_module.httpx, + "get", + MagicMock(return_value=_DummyHttpxResponse({"code": 200, "data": {"url": "u"}})), + ) + assert WebsiteService._get_jinareader_url_data("", "u", "k") == {"url": "u"} + + +def test_get_jinareader_url_data_without_job_id_failure(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(website_service_module.httpx, "get", MagicMock(return_value=_DummyHttpxResponse({"code": 500}))) + with pytest.raises(ValueError, match="Failed to crawl$"): + WebsiteService._get_jinareader_url_data("", "u", "k") + + 
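The jinareader tests in this file all stub `website_service_module.httpx` with objects whose only job is to answer `.json()`. A minimal, self-contained sketch of that pattern follows; the `_StubResponse` class and the example payload are illustrative stand-ins only (the real `_DummyHttpxResponse` helper is defined earlier in this test module, and its exact shape is assumed here, not taken from the patch).

from unittest.mock import MagicMock


class _StubResponse:
    # Illustrative stand-in for the _DummyHttpxResponse helper (assumed to wrap a dict payload).
    def __init__(self, payload: dict) -> None:
        self._payload = payload

    def json(self) -> dict:
        # The service code under test only reads the decoded JSON body,
        # so json() is the only method the stub needs to expose.
        return self._payload


# Usage mirrors the tests above: replace httpx.get / httpx.post with a MagicMock that
# returns the stub, then call the WebsiteService helper and assert on its result, e.g.
#   monkeypatch.setattr(website_service_module.httpx, "get",
#                       MagicMock(return_value=_StubResponse({"code": 200, "data": {}})))
get_mock = MagicMock(return_value=_StubResponse({"code": 200, "data": {"title": "t"}}))
assert get_mock().json() == {"code": 200, "data": {"title": "t"}}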
+def test_get_jinareader_url_data_with_job_id_completed_returns_matching_item(monkeypatch: pytest.MonkeyPatch) -> None: + status_payload = {"data": {"status": "completed", "processed": {"u1": {}}}} + processed_payload = {"data": {"processed": {"u1": {"data": {"url": "u", "title": "t"}}}}} + + post_mock = MagicMock(side_effect=[_DummyHttpxResponse(status_payload), _DummyHttpxResponse(processed_payload)]) + monkeypatch.setattr(website_service_module.httpx, "post", post_mock) + + assert WebsiteService._get_jinareader_url_data("job-1", "u", "k") == {"url": "u", "title": "t"} + assert post_mock.call_count == 2 + + +def test_get_jinareader_url_data_with_job_id_not_completed_raises(monkeypatch: pytest.MonkeyPatch) -> None: + post_mock = MagicMock(return_value=_DummyHttpxResponse({"data": {"status": "active"}})) + monkeypatch.setattr(website_service_module.httpx, "post", post_mock) + + with pytest.raises(ValueError, match=r"Crawl job is no\s*t completed"): + WebsiteService._get_jinareader_url_data("job-1", "u", "k") + + +def test_get_jinareader_url_data_with_job_id_completed_but_not_found_returns_none( + monkeypatch: pytest.MonkeyPatch, +) -> None: + status_payload = {"data": {"status": "completed", "processed": {"u1": {}}}} + processed_payload = {"data": {"processed": {"u1": {"data": {"url": "other"}}}}} + + post_mock = MagicMock(side_effect=[_DummyHttpxResponse(status_payload), _DummyHttpxResponse(processed_payload)]) + monkeypatch.setattr(website_service_module.httpx, "post", post_mock) + + assert WebsiteService._get_jinareader_url_data("job-1", "u", "k") is None + + +def test_get_scrape_url_data_dispatches_and_rejects_invalid_provider(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(WebsiteService, "_get_credentials_and_config", MagicMock(return_value=("k", {"base_url": "b"}))) + + scrape_mock = MagicMock(return_value={"data": "x"}) + monkeypatch.setattr(WebsiteService, "_scrape_with_firecrawl", scrape_mock) + assert WebsiteService.get_scrape_url_data("firecrawl", "u", "tenant-1", True) == {"data": "x"} + scrape_mock.assert_called_once() + + watercrawl_mock = MagicMock(return_value={"data": "y"}) + monkeypatch.setattr(WebsiteService, "_scrape_with_watercrawl", watercrawl_mock) + assert WebsiteService.get_scrape_url_data("watercrawl", "u", "tenant-1", False) == {"data": "y"} + watercrawl_mock.assert_called_once() + + with pytest.raises(ValueError, match="Invalid provider"): + WebsiteService.get_scrape_url_data("jinareader", "u", "tenant-1", True) + + +def test_scrape_with_firecrawl_calls_app(monkeypatch: pytest.MonkeyPatch) -> None: + firecrawl_instance = MagicMock() + firecrawl_instance.scrape_url.return_value = {"markdown": "m"} + monkeypatch.setattr(website_service_module, "FirecrawlApp", MagicMock(return_value=firecrawl_instance)) + + result = WebsiteService._scrape_with_firecrawl( + request=website_service_module.ScrapeRequest( + provider="firecrawl", + url="u", + tenant_id="tenant-1", + only_main_content=True, + ), + api_key="k", + config={"base_url": "b"}, + ) + assert result == {"markdown": "m"} + firecrawl_instance.scrape_url.assert_called_once_with(url="u", params={"onlyMainContent": True}) + + +def test_scrape_with_watercrawl_calls_provider(monkeypatch: pytest.MonkeyPatch) -> None: + provider_instance = MagicMock() + provider_instance.scrape_url.return_value = {"markdown": "m"} + monkeypatch.setattr(website_service_module, "WaterCrawlProvider", MagicMock(return_value=provider_instance)) + + result = WebsiteService._scrape_with_watercrawl( + 
request=website_service_module.ScrapeRequest( + provider="watercrawl", + url="u", + tenant_id="tenant-1", + only_main_content=False, + ), + api_key="k", + config={"base_url": "b"}, + ) + assert result == {"markdown": "m"} + provider_instance.scrape_url.assert_called_once_with("u") From dc50e4c4f29acff2f1cc6fa5b87433d9cb5833fd Mon Sep 17 00:00:00 2001 From: Rajat Agarwal Date: Thu, 12 Mar 2026 13:05:25 +0530 Subject: [PATCH 2/3] test: added test cases for core.workflow module (#33126) --- .../apps/advanced_chat/test_app_generator.py | 15 +- .../test_generate_task_pipeline_core.py | 8 +- .../apps/workflow/test_app_generator_extra.py | 5 +- .../core/workflow/test_node_factory.py | 655 +++++++++++++++-- .../workflow/test_workflow_entry_helpers.py | 656 ++++++++++++++++++ .../workflow/test_workflow_service.py | 8 +- 6 files changed, 1266 insertions(+), 81 deletions(-) create mode 100644 api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_generator.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_generator.py index 8faae3661d..e2618d960c 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_generator.py @@ -150,8 +150,9 @@ class TestAdvancedChatAppGeneratorInternals: "_DummyTraceQueueManager", (TraceQueueManager,), { - "__init__": lambda self, app_id=None, user_id=None: setattr(self, "app_id", app_id) - or setattr(self, "user_id", user_id) + "__init__": lambda self, app_id=None, user_id=None: ( + setattr(self, "app_id", app_id) or setattr(self, "user_id", user_id) + ) }, ) monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.TraceQueueManager", DummyTraceQueueManager) @@ -1124,8 +1125,9 @@ class TestAdvancedChatAppGeneratorInternals: "_DummyTraceQueueManager", (TraceQueueManager,), { - "__init__": lambda self, app_id=None, user_id=None: setattr(self, "app_id", app_id) - or setattr(self, "user_id", user_id) + "__init__": lambda self, app_id=None, user_id=None: ( + setattr(self, "app_id", app_id) or setattr(self, "user_id", user_id) + ) }, ) monkeypatch.setattr( @@ -1202,8 +1204,9 @@ class TestAdvancedChatAppGeneratorInternals: "_DummyTraceQueueManager", (TraceQueueManager,), { - "__init__": lambda self, app_id=None, user_id=None: setattr(self, "app_id", app_id) - or setattr(self, "user_id", user_id) + "__init__": lambda self, app_id=None, user_id=None: ( + setattr(self, "app_id", app_id) or setattr(self, "user_id", user_id) + ) }, ) monkeypatch.setattr( diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py index b348ffc33b..67f87710a1 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py @@ -240,12 +240,12 @@ class TestAdvancedChatGenerateTaskPipeline: def test_iteration_and_loop_handlers(self): pipeline = _make_pipeline() pipeline._workflow_run_id = "run-id" - pipeline._workflow_response_converter.workflow_iteration_start_to_stream_response = ( - lambda **kwargs: "iter_start" + pipeline._workflow_response_converter.workflow_iteration_start_to_stream_response = lambda **kwargs: ( + "iter_start" ) pipeline._workflow_response_converter.workflow_iteration_next_to_stream_response = lambda **kwargs: "iter_next" - 
pipeline._workflow_response_converter.workflow_iteration_completed_to_stream_response = ( - lambda **kwargs: "iter_done" + pipeline._workflow_response_converter.workflow_iteration_completed_to_stream_response = lambda **kwargs: ( + "iter_done" ) pipeline._workflow_response_converter.workflow_loop_start_to_stream_response = lambda **kwargs: "loop_start" pipeline._workflow_response_converter.workflow_loop_next_to_stream_response = lambda **kwargs: "loop_next" diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_app_generator_extra.py b/api/tests/unit_tests/core/app/apps/workflow/test_app_generator_extra.py index 6d6f9272cb..09ad078a70 100644 --- a/api/tests/unit_tests/core/app/apps/workflow/test_app_generator_extra.py +++ b/api/tests/unit_tests/core/app/apps/workflow/test_app_generator_extra.py @@ -144,8 +144,9 @@ class TestWorkflowAppGeneratorGenerate: "_DummyTraceQueueManager", (TraceQueueManager,), { - "__init__": lambda self, app_id=None, user_id=None: setattr(self, "app_id", app_id) - or setattr(self, "user_id", user_id) + "__init__": lambda self, app_id=None, user_id=None: ( + setattr(self, "app_id", app_id) or setattr(self, "user_id", user_id) + ) }, ) monkeypatch.setattr( diff --git a/api/tests/unit_tests/core/workflow/test_node_factory.py b/api/tests/unit_tests/core/workflow/test_node_factory.py index 22be656d4b..4a5f561c22 100644 --- a/api/tests/unit_tests/core/workflow/test_node_factory.py +++ b/api/tests/unit_tests/core/workflow/test_node_factory.py @@ -1,82 +1,603 @@ -from __future__ import annotations +from types import SimpleNamespace +from unittest.mock import MagicMock, patch, sentinel -from typing import Any +import pytest -from core.model_manager import ModelInstance -from core.workflow.node_factory import DifyNodeFactory -from dify_graph.nodes.llm.entities import LLMNodeData -from dify_graph.nodes.llm.node import LLMNode -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from tests.workflow_test_utils import build_test_graph_init_params +from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, UserFrom +from core.workflow import node_factory +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY +from dify_graph.enums import NodeType, SystemVariableKey +from dify_graph.nodes.code.entities import CodeLanguage +from dify_graph.variables.segments import StringSegment -def _build_factory(graph_config: dict[str, Any]) -> DifyNodeFactory: - graph_init_params = build_test_graph_init_params( - workflow_id="workflow", - graph_config=graph_config, - tenant_id="tenant", - app_id="app", - user_id="user", - user_from="account", - invoke_from="debugger", - call_depth=0, +def _assert_typed_node_config(config, *, node_id: str, node_type: NodeType, version: str = "1") -> None: + assert config["id"] == node_id + assert isinstance(config["data"], BaseNodeData) + assert config["data"].type == node_type + assert config["data"].version == version + + +class TestFetchMemory: + @pytest.mark.parametrize( + ("conversation_id", "memory_config"), + [ + (None, object()), + ("conversation-id", None), + ], ) - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool( - system_variables=SystemVariable.default(), - user_inputs={}, - environment_variables=[], - ), - start_at=0.0, + def test_returns_none_when_memory_or_conversation_is_missing(self, conversation_id, memory_config): + result = node_factory.fetch_memory( + 
conversation_id=conversation_id, + app_id="app-id", + node_data_memory=memory_config, + model_instance=sentinel.model_instance, + ) + + assert result is None + + def test_returns_none_when_conversation_does_not_exist(self, monkeypatch): + class FakeSelect: + def where(self, *_args): + return self + + class FakeSession: + def __init__(self, *_args, **_kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, *_args): + return False + + def scalar(self, _stmt): + return None + + monkeypatch.setattr(node_factory, "db", SimpleNamespace(engine=sentinel.engine)) + monkeypatch.setattr(node_factory, "select", MagicMock(return_value=FakeSelect())) + monkeypatch.setattr(node_factory, "Session", FakeSession) + + result = node_factory.fetch_memory( + conversation_id="conversation-id", + app_id="app-id", + node_data_memory=object(), + model_instance=sentinel.model_instance, + ) + + assert result is None + + def test_builds_token_buffer_memory_for_existing_conversation(self, monkeypatch): + conversation = sentinel.conversation + memory = sentinel.memory + + class FakeSelect: + def where(self, *_args): + return self + + class FakeSession: + def __init__(self, *_args, **_kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, *_args): + return False + + def scalar(self, _stmt): + return conversation + + token_buffer_memory = MagicMock(return_value=memory) + monkeypatch.setattr(node_factory, "db", SimpleNamespace(engine=sentinel.engine)) + monkeypatch.setattr(node_factory, "select", MagicMock(return_value=FakeSelect())) + monkeypatch.setattr(node_factory, "Session", FakeSession) + monkeypatch.setattr(node_factory, "TokenBufferMemory", token_buffer_memory) + + result = node_factory.fetch_memory( + conversation_id="conversation-id", + app_id="app-id", + node_data_memory=object(), + model_instance=sentinel.model_instance, + ) + + assert result is memory + token_buffer_memory.assert_called_once_with( + conversation=conversation, + model_instance=sentinel.model_instance, + ) + + +class TestDefaultWorkflowCodeExecutor: + def test_execute_delegates_to_code_executor(self, monkeypatch): + executor = node_factory.DefaultWorkflowCodeExecutor() + execute_workflow_code_template = MagicMock(return_value={"answer": "ok"}) + monkeypatch.setattr( + node_factory.CodeExecutor, + "execute_workflow_code_template", + execute_workflow_code_template, + ) + + result = executor.execute( + language=CodeLanguage.PYTHON3, + code="print('ok')", + inputs={"name": "workflow"}, + ) + + assert result == {"answer": "ok"} + execute_workflow_code_template.assert_called_once_with( + language=CodeLanguage.PYTHON3, + code="print('ok')", + inputs={"name": "workflow"}, + ) + + def test_is_execution_error_checks_code_execution_error_type(self): + executor = node_factory.DefaultWorkflowCodeExecutor() + + assert executor.is_execution_error(node_factory.CodeExecutionError("boom")) is True + assert executor.is_execution_error(RuntimeError("boom")) is False + + +class TestDifyNodeFactoryInit: + def test_init_builds_default_dependencies(self): + graph_init_params = SimpleNamespace(run_context={"context": "value"}) + graph_runtime_state = sentinel.graph_runtime_state + dify_context = SimpleNamespace(tenant_id="tenant-id") + template_renderer = sentinel.template_renderer + rag_retrieval = sentinel.rag_retrieval + unstructured_api_config = sentinel.unstructured_api_config + http_request_config = sentinel.http_request_config + credentials_provider = sentinel.credentials_provider + model_factory = 
sentinel.model_factory + + with ( + patch.object( + node_factory.DifyNodeFactory, + "_resolve_dify_context", + return_value=dify_context, + ) as resolve_dify_context, + patch.object( + node_factory, + "CodeExecutorJinja2TemplateRenderer", + return_value=template_renderer, + ) as renderer_factory, + patch.object(node_factory, "DatasetRetrieval", return_value=rag_retrieval), + patch.object( + node_factory, + "UnstructuredApiConfig", + return_value=unstructured_api_config, + ), + patch.object( + node_factory, + "build_http_request_config", + return_value=http_request_config, + ), + patch.object( + node_factory, + "build_dify_model_access", + return_value=(credentials_provider, model_factory), + ) as build_dify_model_access, + ): + factory = node_factory.DifyNodeFactory( + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + + resolve_dify_context.assert_called_once_with(graph_init_params.run_context) + build_dify_model_access.assert_called_once_with("tenant-id") + renderer_factory.assert_called_once() + assert renderer_factory.call_args.kwargs["code_executor"] is factory._code_executor + assert factory.graph_init_params is graph_init_params + assert factory.graph_runtime_state is graph_runtime_state + assert factory._dify_context is dify_context + assert factory._template_renderer is template_renderer + assert factory._rag_retrieval is rag_retrieval + assert factory._document_extractor_unstructured_api_config is unstructured_api_config + assert factory._http_request_config is http_request_config + assert factory._llm_credentials_provider is credentials_provider + assert factory._llm_model_factory is model_factory + + +class TestDifyNodeFactoryResolveContext: + def test_requires_reserved_context_key(self): + with pytest.raises(ValueError, match=DIFY_RUN_CONTEXT_KEY): + node_factory.DifyNodeFactory._resolve_dify_context({}) + + def test_returns_existing_dify_context(self): + dify_context = DifyRunContext( + tenant_id="tenant-id", + app_id="app-id", + user_id="user-id", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ) + + result = node_factory.DifyNodeFactory._resolve_dify_context({DIFY_RUN_CONTEXT_KEY: dify_context}) + + assert result is dify_context + + def test_validates_mapping_context(self): + raw_context = { + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "tenant-id", + "app_id": "app-id", + "user_id": "user-id", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.DEBUGGER, + } + } + + result = node_factory.DifyNodeFactory._resolve_dify_context(raw_context) + + assert isinstance(result, DifyRunContext) + assert result.tenant_id == "tenant-id" + + +class TestDifyNodeFactoryCreateNode: + @pytest.fixture + def factory(self): + factory = object.__new__(node_factory.DifyNodeFactory) + factory.graph_init_params = sentinel.graph_init_params + factory.graph_runtime_state = sentinel.graph_runtime_state + factory._dify_context = SimpleNamespace(tenant_id="tenant-id", app_id="app-id") + factory._code_executor = sentinel.code_executor + factory._code_limits = sentinel.code_limits + factory._template_renderer = sentinel.template_renderer + factory._template_transform_max_output_length = 2048 + factory._http_request_http_client = sentinel.http_client + factory._http_request_tool_file_manager_factory = sentinel.tool_file_manager_factory + factory._http_request_file_manager = sentinel.file_manager + factory._rag_retrieval = sentinel.rag_retrieval + factory._document_extractor_unstructured_api_config = sentinel.unstructured_api_config + 
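+        # Note: the fixture builds the factory via object.__new__ and wires sentinel
+        # dependencies in directly, bypassing __init__; the create_node tests below can
+        # then assert by identity that each dependency is forwarded unchanged to the
+        # node constructors.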
factory._http_request_config = sentinel.http_request_config + factory._llm_credentials_provider = sentinel.credentials_provider + factory._llm_model_factory = sentinel.model_factory + return factory + + def test_rejects_unknown_node_type(self, factory): + with pytest.raises(ValueError, match="Input should be"): + factory.create_node({"id": "node-id", "data": {"type": "missing"}}) + + def test_rejects_missing_class_mapping(self, monkeypatch, factory): + monkeypatch.setattr(node_factory, "NODE_TYPE_CLASSES_MAPPING", {}) + + with pytest.raises(ValueError, match="No class mapping found for node type: start"): + factory.create_node({"id": "node-id", "data": {"type": NodeType.START.value}}) + + def test_rejects_missing_latest_class(self, monkeypatch, factory): + monkeypatch.setattr( + node_factory, + "NODE_TYPE_CLASSES_MAPPING", + {NodeType.START: {node_factory.LATEST_VERSION: None}}, + ) + + with pytest.raises(ValueError, match="No latest version class found for node type: start"): + factory.create_node({"id": "node-id", "data": {"type": NodeType.START.value}}) + + def test_uses_version_specific_class_when_available(self, monkeypatch, factory): + matched_node = sentinel.matched_node + latest_node_class = MagicMock(return_value=sentinel.latest_node) + matched_node_class = MagicMock(return_value=matched_node) + monkeypatch.setattr( + node_factory, + "NODE_TYPE_CLASSES_MAPPING", + { + NodeType.START: { + node_factory.LATEST_VERSION: latest_node_class, + "9": matched_node_class, + } + }, + ) + + result = factory.create_node({"id": "node-id", "data": {"type": NodeType.START.value, "version": "9"}}) + + assert result is matched_node + matched_node_class.assert_called_once() + kwargs = matched_node_class.call_args.kwargs + assert kwargs["id"] == "node-id" + _assert_typed_node_config(kwargs["config"], node_id="node-id", node_type=NodeType.START, version="9") + assert kwargs["graph_init_params"] is sentinel.graph_init_params + assert kwargs["graph_runtime_state"] is sentinel.graph_runtime_state + latest_node_class.assert_not_called() + + def test_falls_back_to_latest_class_when_version_specific_mapping_is_missing(self, monkeypatch, factory): + latest_node = sentinel.latest_node + latest_node_class = MagicMock(return_value=latest_node) + monkeypatch.setattr( + node_factory, + "NODE_TYPE_CLASSES_MAPPING", + {NodeType.START: {node_factory.LATEST_VERSION: latest_node_class}}, + ) + + result = factory.create_node({"id": "node-id", "data": {"type": NodeType.START.value, "version": "9"}}) + + assert result is latest_node + latest_node_class.assert_called_once() + kwargs = latest_node_class.call_args.kwargs + assert kwargs["id"] == "node-id" + _assert_typed_node_config(kwargs["config"], node_id="node-id", node_type=NodeType.START, version="9") + assert kwargs["graph_init_params"] is sentinel.graph_init_params + assert kwargs["graph_runtime_state"] is sentinel.graph_runtime_state + + @pytest.mark.parametrize( + ("node_type", "constructor_name"), + [ + (NodeType.CODE, "CodeNode"), + (NodeType.TEMPLATE_TRANSFORM, "TemplateTransformNode"), + (NodeType.HTTP_REQUEST, "HttpRequestNode"), + (NodeType.HUMAN_INPUT, "HumanInputNode"), + (NodeType.KNOWLEDGE_INDEX, "KnowledgeIndexNode"), + (NodeType.DATASOURCE, "DatasourceNode"), + (NodeType.KNOWLEDGE_RETRIEVAL, "KnowledgeRetrievalNode"), + (NodeType.DOCUMENT_EXTRACTOR, "DocumentExtractorNode"), + ], ) - return DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state) + def test_creates_specialized_nodes(self, monkeypatch, factory, 
node_type, constructor_name): + created_node = object() + constructor = MagicMock(name=constructor_name, return_value=created_node) + monkeypatch.setattr( + node_factory, + "NODE_TYPE_CLASSES_MAPPING", + {node_type: {node_factory.LATEST_VERSION: constructor}}, + ) + + if constructor_name == "HumanInputNode": + form_repository = sentinel.form_repository + form_repository_impl = MagicMock(return_value=form_repository) + monkeypatch.setattr( + node_factory, + "HumanInputFormRepositoryImpl", + form_repository_impl, + ) + elif constructor_name == "KnowledgeIndexNode": + index_processor = sentinel.index_processor + summary_index = sentinel.summary_index + monkeypatch.setattr(node_factory, "IndexProcessor", MagicMock(return_value=index_processor)) + monkeypatch.setattr(node_factory, "SummaryIndex", MagicMock(return_value=summary_index)) + + node_config = {"id": "node-id", "data": {"type": node_type.value}} + result = factory.create_node(node_config) + + assert result is created_node + kwargs = constructor.call_args.kwargs + assert kwargs["id"] == "node-id" + _assert_typed_node_config(kwargs["config"], node_id="node-id", node_type=node_type) + assert kwargs["graph_init_params"] is sentinel.graph_init_params + assert kwargs["graph_runtime_state"] is sentinel.graph_runtime_state + + if constructor_name == "CodeNode": + assert kwargs["code_executor"] is sentinel.code_executor + assert kwargs["code_limits"] is sentinel.code_limits + elif constructor_name == "TemplateTransformNode": + assert kwargs["template_renderer"] is sentinel.template_renderer + assert kwargs["max_output_length"] == 2048 + elif constructor_name == "HttpRequestNode": + assert kwargs["http_request_config"] is sentinel.http_request_config + assert kwargs["http_client"] is sentinel.http_client + assert kwargs["tool_file_manager_factory"] is sentinel.tool_file_manager_factory + assert kwargs["file_manager"] is sentinel.file_manager + elif constructor_name == "HumanInputNode": + assert kwargs["form_repository"] is form_repository + form_repository_impl.assert_called_once_with(tenant_id="tenant-id") + elif constructor_name == "KnowledgeIndexNode": + assert kwargs["index_processor"] is index_processor + assert kwargs["summary_index_service"] is summary_index + elif constructor_name == "DatasourceNode": + assert kwargs["datasource_manager"] is node_factory.DatasourceManager + elif constructor_name == "KnowledgeRetrievalNode": + assert kwargs["rag_retrieval"] is sentinel.rag_retrieval + elif constructor_name == "DocumentExtractorNode": + assert kwargs["unstructured_api_config"] is sentinel.unstructured_api_config + assert kwargs["http_client"] is sentinel.http_client + + @pytest.mark.parametrize( + ("node_type", "constructor_name", "expected_extra_kwargs"), + [ + (NodeType.LLM, "LLMNode", {"http_client": sentinel.http_client}), + (NodeType.QUESTION_CLASSIFIER, "QuestionClassifierNode", {"http_client": sentinel.http_client}), + (NodeType.PARAMETER_EXTRACTOR, "ParameterExtractorNode", {}), + ], + ) + def test_creates_model_backed_nodes( + self, + monkeypatch, + factory, + node_type, + constructor_name, + expected_extra_kwargs, + ): + created_node = object() + constructor = MagicMock(name=constructor_name, return_value=created_node) + monkeypatch.setattr( + node_factory, + "NODE_TYPE_CLASSES_MAPPING", + {node_type: {node_factory.LATEST_VERSION: constructor}}, + ) + llm_init_kwargs = { + "credentials_provider": sentinel.credentials_provider, + "model_factory": sentinel.model_factory, + "model_instance": sentinel.model_instance, + "memory": 
sentinel.memory, + **expected_extra_kwargs, + } + build_llm_init_kwargs = MagicMock(return_value=llm_init_kwargs) + factory._build_llm_compatible_node_init_kwargs = build_llm_init_kwargs + + node_config = {"id": "node-id", "data": {"type": node_type.value}} + result = factory.create_node(node_config) + + assert result is created_node + build_llm_init_kwargs.assert_called_once() + helper_kwargs = build_llm_init_kwargs.call_args.kwargs + assert helper_kwargs["node_class"] is constructor + assert isinstance(helper_kwargs["node_data"], BaseNodeData) + assert helper_kwargs["node_data"].type == node_type + assert helper_kwargs["include_http_client"] is (node_type != NodeType.PARAMETER_EXTRACTOR) + + constructor_kwargs = constructor.call_args.kwargs + assert constructor_kwargs["id"] == "node-id" + _assert_typed_node_config(constructor_kwargs["config"], node_id="node-id", node_type=node_type) + assert constructor_kwargs["graph_init_params"] is sentinel.graph_init_params + assert constructor_kwargs["graph_runtime_state"] is sentinel.graph_runtime_state + assert constructor_kwargs["credentials_provider"] is sentinel.credentials_provider + assert constructor_kwargs["model_factory"] is sentinel.model_factory + assert constructor_kwargs["model_instance"] is sentinel.model_instance + assert constructor_kwargs["memory"] is sentinel.memory + for key, value in expected_extra_kwargs.items(): + assert constructor_kwargs[key] is value -def test_create_node_uses_declared_node_data_type_for_llm_validation(monkeypatch): - class _FactoryLLMNodeData(LLMNodeData): - pass +class TestDifyNodeFactoryModelInstance: + @pytest.fixture + def factory(self): + factory = object.__new__(node_factory.DifyNodeFactory) + factory._llm_credentials_provider = MagicMock() + factory._llm_model_factory = MagicMock() + return factory - llm_node_config = { - "id": "llm-node", - "data": { - "type": "llm", - "title": "LLM", - "model": { - "provider": "openai", - "name": "gpt-4o-mini", - "mode": "chat", - "completion_params": {}, - }, - "prompt_template": [], - "context": { - "enabled": False, - }, - }, - } - graph_config = {"nodes": [llm_node_config], "edges": []} - factory = _build_factory(graph_config) - captured: dict[str, object] = {} + @pytest.fixture + def llm_model_setup(self, factory): + def _configure( + *, + completion_params=None, + has_provider_model=True, + model_schema=sentinel.model_schema, + ): + credentials = {"api_key": "secret"} + node_data_model = SimpleNamespace( + provider="provider", + name="model", + mode="chat", + completion_params=completion_params or {}, + ) + node_data = SimpleNamespace(model=node_data_model) + provider_model = MagicMock() if has_provider_model else None + provider_model_bundle = SimpleNamespace( + configuration=SimpleNamespace(get_provider_model=MagicMock(return_value=provider_model)) + ) + model_type_instance = MagicMock() + model_type_instance.get_model_schema.return_value = model_schema + model_instance = SimpleNamespace( + provider_model_bundle=provider_model_bundle, + model_type_instance=model_type_instance, + provider=None, + model_name=None, + credentials=None, + parameters=None, + stop=None, + ) + factory._llm_credentials_provider.fetch.return_value = credentials + factory._llm_model_factory.init_model_instance.return_value = model_instance + return SimpleNamespace( + node_data=node_data, + credentials=credentials, + provider_model=provider_model, + model_type_instance=model_type_instance, + model_instance=model_instance, + ) - monkeypatch.setattr(LLMNode, "_node_data_type", 
_FactoryLLMNodeData) + return _configure - def _capture_model_instance(self: DifyNodeFactory, node_data: object) -> ModelInstance: - captured["node_data"] = node_data - return object() # type: ignore[return-value] + def test_requires_llm_mode(self, factory): + node_data = SimpleNamespace( + model=SimpleNamespace( + provider="provider", + name="model", + mode="", + completion_params={}, + ) + ) - def _capture_memory( - self: DifyNodeFactory, - *, - node_data: object, - model_instance: ModelInstance, - ) -> None: - captured["memory_node_data"] = node_data + with pytest.raises(node_factory.LLMModeRequiredError, match="LLM mode is required"): + factory._build_model_instance_for_llm_node(node_data) - monkeypatch.setattr(DifyNodeFactory, "_build_model_instance_for_llm_node", _capture_model_instance) - monkeypatch.setattr(DifyNodeFactory, "_build_memory_for_llm_node", _capture_memory) + def test_raises_when_provider_model_is_missing(self, factory, llm_model_setup): + setup = llm_model_setup(has_provider_model=False) - node = factory.create_node(llm_node_config) + with pytest.raises(node_factory.ModelNotExistError, match="Model model not exist"): + factory._build_model_instance_for_llm_node(setup.node_data) - assert isinstance(captured["node_data"], _FactoryLLMNodeData) - assert isinstance(captured["memory_node_data"], _FactoryLLMNodeData) - assert isinstance(node.node_data, _FactoryLLMNodeData) + def test_raises_when_model_schema_is_missing(self, factory, llm_model_setup): + setup = llm_model_setup(model_schema=None) + + with pytest.raises(node_factory.ModelNotExistError, match="Model model not exist"): + factory._build_model_instance_for_llm_node(setup.node_data) + + setup.provider_model.raise_for_status.assert_called_once() + + def test_builds_model_instance_and_normalizes_stop_tokens(self, factory, llm_model_setup): + setup = llm_model_setup( + completion_params={"temperature": 0.3, "stop": "not-a-list"}, + model_schema={"schema": "value"}, + ) + + result = factory._build_model_instance_for_llm_node(setup.node_data) + + assert result is setup.model_instance + assert result.provider == "provider" + assert result.model_name == "model" + assert result.credentials == setup.credentials + assert result.parameters == {"temperature": 0.3} + assert result.stop == () + assert result.model_type_instance is setup.model_type_instance + setup.provider_model.raise_for_status.assert_called_once() + + +class TestDifyNodeFactoryMemory: + @pytest.fixture + def factory(self): + factory = object.__new__(node_factory.DifyNodeFactory) + factory._dify_context = SimpleNamespace(app_id="app-id") + factory.graph_runtime_state = SimpleNamespace(variable_pool=MagicMock()) + return factory + + def test_returns_none_when_memory_is_not_configured(self, factory): + result = factory._build_memory_for_llm_node( + node_data=SimpleNamespace(memory=None), + model_instance=sentinel.model_instance, + ) + + assert result is None + factory.graph_runtime_state.variable_pool.get.assert_not_called() + + def test_uses_string_segment_conversation_id(self, monkeypatch, factory): + memory_config = sentinel.memory_config + factory.graph_runtime_state.variable_pool.get.return_value = StringSegment(value="conversation-id") + fetch_memory = MagicMock(return_value=sentinel.memory) + monkeypatch.setattr(node_factory, "fetch_memory", fetch_memory) + + result = factory._build_memory_for_llm_node( + node_data=SimpleNamespace(memory=memory_config), + model_instance=sentinel.model_instance, + ) + + assert result is sentinel.memory + 
factory.graph_runtime_state.variable_pool.get.assert_called_once_with( + ["sys", SystemVariableKey.CONVERSATION_ID] + ) + fetch_memory.assert_called_once_with( + conversation_id="conversation-id", + app_id="app-id", + node_data_memory=memory_config, + model_instance=sentinel.model_instance, + ) + + def test_ignores_non_string_segment_conversation_ids(self, monkeypatch, factory): + memory_config = sentinel.memory_config + factory.graph_runtime_state.variable_pool.get.return_value = sentinel.segment + fetch_memory = MagicMock(return_value=sentinel.memory) + monkeypatch.setattr(node_factory, "fetch_memory", fetch_memory) + + result = factory._build_memory_for_llm_node( + node_data=SimpleNamespace(memory=memory_config), + model_instance=sentinel.model_instance, + ) + + assert result is sentinel.memory + fetch_memory.assert_called_once_with( + conversation_id=None, + app_id="app-id", + node_data_memory=memory_config, + model_instance=sentinel.model_instance, + ) diff --git a/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py b/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py new file mode 100644 index 0000000000..fe211fb76a --- /dev/null +++ b/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py @@ -0,0 +1,656 @@ +from collections import UserString +from types import SimpleNamespace +from unittest.mock import MagicMock, patch, sentinel + +import pytest + +from core.app.apps.exc import GenerateTaskStoppedError +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.workflow import workflow_entry +from dify_graph.entities.graph_config import NodeConfigDictAdapter +from dify_graph.errors import WorkflowNodeRunFailedError +from dify_graph.file.enums import FileTransferMethod, FileType +from dify_graph.file.models import File +from dify_graph.graph_events import GraphRunFailedEvent +from dify_graph.nodes import NodeType +from dify_graph.runtime import ChildGraphNotFoundError + + +def _build_typed_node_config(node_type: NodeType): + return NodeConfigDictAdapter.validate_python({"id": "node-id", "data": {"type": node_type}}) + + +class TestWorkflowChildEngineBuilder: + @pytest.mark.parametrize( + ("graph_config", "node_id", "expected"), + [ + ({"nodes": [{"id": "root"}]}, "root", True), + ({"nodes": [{"id": "root"}]}, "other", False), + ({"nodes": "invalid"}, "root", None), + ({"nodes": ["invalid"]}, "root", None), + ], + ) + def test_has_node_id(self, graph_config, node_id, expected): + result = workflow_entry._WorkflowChildEngineBuilder._has_node_id(graph_config, node_id) + + assert result is expected + + def test_build_child_engine_raises_when_root_node_is_missing(self): + builder = workflow_entry._WorkflowChildEngineBuilder() + + with patch.object(workflow_entry, "DifyNodeFactory", return_value=sentinel.factory): + with pytest.raises(ChildGraphNotFoundError, match="child graph root node 'missing' not found"): + builder.build_child_engine( + workflow_id="workflow-id", + graph_init_params=sentinel.graph_init_params, + graph_runtime_state=sentinel.graph_runtime_state, + graph_config={"nodes": []}, + root_node_id="missing", + ) + + def test_build_child_engine_constructs_graph_engine_and_layers(self): + builder = workflow_entry._WorkflowChildEngineBuilder() + child_graph = sentinel.child_graph + child_engine = MagicMock() + quota_layer = sentinel.quota_layer + additional_layers = [sentinel.layer_one, sentinel.layer_two] + + with ( + patch.object(workflow_entry, "DifyNodeFactory", return_value=sentinel.factory) as dify_node_factory, 
+ patch.object(workflow_entry.Graph, "init", return_value=child_graph) as graph_init, + patch.object(workflow_entry, "GraphEngine", return_value=child_engine) as graph_engine_cls, + patch.object(workflow_entry, "GraphEngineConfig", return_value=sentinel.graph_engine_config), + patch.object(workflow_entry, "InMemoryChannel", return_value=sentinel.command_channel), + patch.object(workflow_entry, "LLMQuotaLayer", return_value=quota_layer), + ): + result = builder.build_child_engine( + workflow_id="workflow-id", + graph_init_params=sentinel.graph_init_params, + graph_runtime_state=sentinel.graph_runtime_state, + graph_config={"nodes": [{"id": "root"}]}, + root_node_id="root", + layers=additional_layers, + ) + + assert result is child_engine + dify_node_factory.assert_called_once_with( + graph_init_params=sentinel.graph_init_params, + graph_runtime_state=sentinel.graph_runtime_state, + ) + graph_init.assert_called_once_with( + graph_config={"nodes": [{"id": "root"}]}, + node_factory=sentinel.factory, + root_node_id="root", + ) + graph_engine_cls.assert_called_once_with( + workflow_id="workflow-id", + graph=child_graph, + graph_runtime_state=sentinel.graph_runtime_state, + command_channel=sentinel.command_channel, + config=sentinel.graph_engine_config, + child_engine_builder=builder, + ) + assert child_engine.layer.call_args_list == [ + ((quota_layer,), {}), + ((sentinel.layer_one,), {}), + ((sentinel.layer_two,), {}), + ] + + +class TestWorkflowEntryInit: + def test_rejects_call_depth_above_limit(self): + call_depth = workflow_entry.dify_config.WORKFLOW_CALL_MAX_DEPTH + 1 + + with pytest.raises(ValueError, match="Max workflow call depth"): + workflow_entry.WorkflowEntry( + tenant_id="tenant-id", + app_id="app-id", + workflow_id="workflow-id", + graph_config={"nodes": [], "edges": []}, + graph=sentinel.graph, + user_id="user-id", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=call_depth, + variable_pool=sentinel.variable_pool, + graph_runtime_state=sentinel.graph_runtime_state, + ) + + def test_applies_debug_and_observability_layers(self): + graph_engine = MagicMock() + debug_layer = sentinel.debug_layer + execution_limits_layer = sentinel.execution_limits_layer + llm_quota_layer = sentinel.llm_quota_layer + observability_layer = sentinel.observability_layer + + with ( + patch.object(workflow_entry.dify_config, "DEBUG", True), + patch.object(workflow_entry.dify_config, "ENABLE_OTEL", False), + patch.object(workflow_entry, "is_instrument_flag_enabled", return_value=True), + patch.object(workflow_entry, "GraphEngine", return_value=graph_engine) as graph_engine_cls, + patch.object(workflow_entry, "GraphEngineConfig", return_value=sentinel.graph_engine_config), + patch.object(workflow_entry, "InMemoryChannel", return_value=sentinel.command_channel), + patch.object(workflow_entry, "DebugLoggingLayer", return_value=debug_layer) as debug_logging_layer, + patch.object( + workflow_entry, + "ExecutionLimitsLayer", + return_value=execution_limits_layer, + ) as execution_limits_layer_cls, + patch.object(workflow_entry, "LLMQuotaLayer", return_value=llm_quota_layer), + patch.object(workflow_entry, "ObservabilityLayer", return_value=observability_layer), + ): + entry = workflow_entry.WorkflowEntry( + tenant_id="tenant-id", + app_id="app-id", + workflow_id="workflow-id-123456", + graph_config={"nodes": [], "edges": []}, + graph=sentinel.graph, + user_id="user-id", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + 
variable_pool=sentinel.variable_pool, + graph_runtime_state=sentinel.graph_runtime_state, + command_channel=None, + ) + + assert entry.command_channel is sentinel.command_channel + graph_engine_cls.assert_called_once_with( + workflow_id="workflow-id-123456", + graph=sentinel.graph, + graph_runtime_state=sentinel.graph_runtime_state, + command_channel=sentinel.command_channel, + config=sentinel.graph_engine_config, + child_engine_builder=entry._child_engine_builder, + ) + debug_logging_layer.assert_called_once_with( + level="DEBUG", + include_inputs=True, + include_outputs=True, + include_process_data=False, + logger_name="GraphEngine.Debug.workflow", + ) + execution_limits_layer_cls.assert_called_once_with( + max_steps=workflow_entry.dify_config.WORKFLOW_MAX_EXECUTION_STEPS, + max_time=workflow_entry.dify_config.WORKFLOW_MAX_EXECUTION_TIME, + ) + assert graph_engine.layer.call_args_list == [ + ((debug_layer,), {}), + ((execution_limits_layer,), {}), + ((llm_quota_layer,), {}), + ((observability_layer,), {}), + ] + + +class TestWorkflowEntryRun: + def test_run_swallows_generate_task_stopped_errors(self): + entry = object.__new__(workflow_entry.WorkflowEntry) + entry.graph_engine = MagicMock() + entry.graph_engine.run.side_effect = GenerateTaskStoppedError() + + assert list(entry.run()) == [] + + def test_run_emits_failed_event_for_unexpected_errors(self): + entry = object.__new__(workflow_entry.WorkflowEntry) + entry.graph_engine = MagicMock() + entry.graph_engine.run.side_effect = RuntimeError("boom") + + events = list(entry.run()) + + assert len(events) == 1 + assert isinstance(events[0], GraphRunFailedEvent) + assert events[0].error == "boom" + + +class TestWorkflowEntrySingleStepRun: + def test_uses_empty_mapping_when_selector_extraction_is_not_implemented(self): + class FakeNode: + id = "node-id" + title = "Node Title" + node_type = "fake" + + @staticmethod + def version(): + return "1" + + @staticmethod + def extract_variable_selector_to_variable_mapping(**_kwargs): + raise NotImplementedError + + with ( + patch.object(workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params), + patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state), + patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}), + patch.object(workflow_entry.time, "perf_counter", return_value=123.0), + patch.object(workflow_entry, "DifyNodeFactory") as dify_node_factory, + patch.object(workflow_entry, "load_into_variable_pool") as load_into_variable_pool, + patch.object( + workflow_entry.WorkflowEntry, + "mapping_user_inputs_to_variable_pool", + ) as mapping_user_inputs_to_variable_pool, + patch.object( + workflow_entry.WorkflowEntry, + "_traced_node_run", + return_value=iter(["event"]), + ), + ): + dify_node_factory.return_value.create_node.return_value = FakeNode() + workflow = SimpleNamespace( + tenant_id="tenant-id", + app_id="app-id", + id="workflow-id", + graph_dict={"nodes": [], "edges": []}, + get_node_config_by_id=lambda _node_id: _build_typed_node_config(NodeType.START), + ) + + node, generator = workflow_entry.WorkflowEntry.single_step_run( + workflow=workflow, + node_id="node-id", + user_id="user-id", + user_inputs={"question": "hello"}, + variable_pool=sentinel.variable_pool, + ) + + assert node.id == "node-id" + assert list(generator) == ["event"] + load_into_variable_pool.assert_called_once_with( + variable_loader=workflow_entry.DUMMY_VARIABLE_LOADER, + variable_pool=sentinel.variable_pool, + variable_mapping={}, + 
user_inputs={"question": "hello"}, + ) + mapping_user_inputs_to_variable_pool.assert_called_once_with( + variable_mapping={}, + user_inputs={"question": "hello"}, + variable_pool=sentinel.variable_pool, + tenant_id="tenant-id", + ) + + def test_skips_user_input_mapping_for_datasource_nodes(self): + class FakeDatasourceNode: + id = "node-id" + node_type = "datasource" + + @staticmethod + def version(): + return "1" + + @staticmethod + def extract_variable_selector_to_variable_mapping(**_kwargs): + return {"question": ["node", "question"]} + + with ( + patch.object(workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params), + patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state), + patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}), + patch.object(workflow_entry.time, "perf_counter", return_value=123.0), + patch.object(workflow_entry, "DifyNodeFactory") as dify_node_factory, + patch.object(workflow_entry, "load_into_variable_pool") as load_into_variable_pool, + patch.object( + workflow_entry.WorkflowEntry, + "mapping_user_inputs_to_variable_pool", + ) as mapping_user_inputs_to_variable_pool, + patch.object( + workflow_entry.WorkflowEntry, + "_traced_node_run", + return_value=iter(["event"]), + ), + ): + dify_node_factory.return_value.create_node.return_value = FakeDatasourceNode() + workflow = SimpleNamespace( + tenant_id="tenant-id", + app_id="app-id", + id="workflow-id", + graph_dict={"nodes": [], "edges": []}, + get_node_config_by_id=lambda _node_id: _build_typed_node_config(NodeType.DATASOURCE), + ) + + node, generator = workflow_entry.WorkflowEntry.single_step_run( + workflow=workflow, + node_id="node-id", + user_id="user-id", + user_inputs={"question": "hello"}, + variable_pool=sentinel.variable_pool, + ) + + assert node.id == "node-id" + assert list(generator) == ["event"] + load_into_variable_pool.assert_called_once() + mapping_user_inputs_to_variable_pool.assert_not_called() + + def test_wraps_traced_node_run_failures(self): + class FakeNode: + id = "node-id" + title = "Node Title" + node_type = "fake" + + @staticmethod + def extract_variable_selector_to_variable_mapping(**_kwargs): + return {} + + @staticmethod + def version(): + return "1" + + with ( + patch.object(workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params), + patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state), + patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}), + patch.object(workflow_entry.time, "perf_counter", return_value=123.0), + patch.object(workflow_entry, "DifyNodeFactory") as dify_node_factory, + patch.object(workflow_entry, "load_into_variable_pool"), + patch.object(workflow_entry.WorkflowEntry, "mapping_user_inputs_to_variable_pool"), + patch.object( + workflow_entry.WorkflowEntry, + "_traced_node_run", + side_effect=RuntimeError("boom"), + ), + ): + dify_node_factory.return_value.create_node.return_value = FakeNode() + workflow = SimpleNamespace( + tenant_id="tenant-id", + app_id="app-id", + id="workflow-id", + graph_dict={"nodes": [], "edges": []}, + get_node_config_by_id=lambda _node_id: _build_typed_node_config(NodeType.START), + ) + + with pytest.raises(WorkflowNodeRunFailedError): + workflow_entry.WorkflowEntry.single_step_run( + workflow=workflow, + node_id="node-id", + user_id="user-id", + user_inputs={}, + variable_pool=sentinel.variable_pool, + ) + + +class TestWorkflowEntryHelpers: + def 
test_create_single_node_graph_builds_start_edge(self): + graph = workflow_entry.WorkflowEntry._create_single_node_graph( + node_id="target-node", + node_data={"type": NodeType.PARAMETER_EXTRACTOR}, + node_width=320, + node_height=180, + ) + + assert graph["nodes"][0]["id"] == "start" + assert graph["nodes"][1]["id"] == "target-node" + assert graph["nodes"][1]["width"] == 320 + assert graph["nodes"][1]["height"] == 180 + assert graph["edges"] == [ + { + "source": "start", + "target": "target-node", + "sourceHandle": "source", + "targetHandle": "target", + } + ] + + def test_run_free_node_rejects_unsupported_types(self): + with pytest.raises(ValueError, match="Node type start not supported"): + workflow_entry.WorkflowEntry.run_free_node( + node_data={"type": NodeType.START.value}, + node_id="node-id", + tenant_id="tenant-id", + user_id="user-id", + user_inputs={}, + ) + + def test_run_free_node_rejects_missing_node_class(self, monkeypatch): + monkeypatch.setattr( + workflow_entry, + "NODE_TYPE_CLASSES_MAPPING", + {NodeType.PARAMETER_EXTRACTOR: {"1": None}}, + ) + + with pytest.raises(ValueError, match="Node class not found for node type parameter-extractor"): + workflow_entry.WorkflowEntry.run_free_node( + node_data={"type": NodeType.PARAMETER_EXTRACTOR.value}, + node_id="node-id", + tenant_id="tenant-id", + user_id="user-id", + user_inputs={}, + ) + + def test_run_free_node_uses_empty_mapping_when_selector_extraction_is_not_implemented(self, monkeypatch): + class FakeNodeClass: + @staticmethod + def extract_variable_selector_to_variable_mapping(**_kwargs): + raise NotImplementedError + + class FakeNode: + id = "node-id" + title = "Node Title" + node_type = "parameter-extractor" + + @staticmethod + def version(): + return "1" + + dify_node_factory = MagicMock() + dify_node_factory.create_node.return_value = FakeNode() + monkeypatch.setattr( + workflow_entry, + "NODE_TYPE_CLASSES_MAPPING", + {NodeType.PARAMETER_EXTRACTOR: {"1": FakeNodeClass}}, + ) + + with ( + patch.object(workflow_entry.SystemVariable, "default", return_value=sentinel.system_variables), + patch.object(workflow_entry, "VariablePool", return_value=sentinel.variable_pool) as variable_pool_cls, + patch.object( + workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params + ) as graph_init_params, + patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state), + patch.object( + workflow_entry, "build_dify_run_context", return_value={"_dify": "context"} + ) as build_dify_run_context, + patch.object(workflow_entry.time, "perf_counter", return_value=123.0), + patch.object(workflow_entry, "DifyNodeFactory", return_value=dify_node_factory) as dify_node_factory_cls, + patch.object( + workflow_entry.WorkflowEntry, + "mapping_user_inputs_to_variable_pool", + ) as mapping_user_inputs_to_variable_pool, + patch.object( + workflow_entry.WorkflowEntry, + "_traced_node_run", + return_value=iter(["event"]), + ), + ): + node, generator = workflow_entry.WorkflowEntry.run_free_node( + node_data={"type": NodeType.PARAMETER_EXTRACTOR.value, "title": "Node"}, + node_id="node-id", + tenant_id="tenant-id", + user_id="user-id", + user_inputs={"question": "hello"}, + ) + + assert node.id == "node-id" + assert list(generator) == ["event"] + variable_pool_cls.assert_called_once_with( + system_variables=sentinel.system_variables, + user_inputs={}, + environment_variables=[], + ) + build_dify_run_context.assert_called_once_with( + tenant_id="tenant-id", + app_id="", + user_id="user-id", + 
user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ) + graph_init_params.assert_called_once_with( + workflow_id="", + graph_config=workflow_entry.WorkflowEntry._create_single_node_graph( + "node-id", {"type": NodeType.PARAMETER_EXTRACTOR.value, "title": "Node"} + ), + run_context={"_dify": "context"}, + call_depth=0, + ) + dify_node_factory_cls.assert_called_once_with( + graph_init_params=sentinel.graph_init_params, + graph_runtime_state=sentinel.graph_runtime_state, + ) + mapping_user_inputs_to_variable_pool.assert_called_once_with( + variable_mapping={}, + user_inputs={"question": "hello"}, + variable_pool=sentinel.variable_pool, + tenant_id="tenant-id", + ) + + def test_run_free_node_wraps_execution_failures(self, monkeypatch): + class FakeNodeClass: + @staticmethod + def extract_variable_selector_to_variable_mapping(**_kwargs): + return {} + + class FakeNode: + id = "node-id" + title = "Node Title" + node_type = "parameter-extractor" + + @staticmethod + def version(): + return "1" + + dify_node_factory = MagicMock() + dify_node_factory.create_node.return_value = FakeNode() + monkeypatch.setattr( + workflow_entry, + "NODE_TYPE_CLASSES_MAPPING", + {NodeType.PARAMETER_EXTRACTOR: {"1": FakeNodeClass}}, + ) + + with ( + patch.object(workflow_entry.SystemVariable, "default", return_value=sentinel.system_variables), + patch.object(workflow_entry, "VariablePool", return_value=sentinel.variable_pool), + patch.object(workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params), + patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state), + patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}), + patch.object(workflow_entry.time, "perf_counter", return_value=123.0), + patch.object(workflow_entry, "DifyNodeFactory", return_value=dify_node_factory), + patch.object( + workflow_entry.WorkflowEntry, + "mapping_user_inputs_to_variable_pool", + side_effect=RuntimeError("boom"), + ), + ): + with pytest.raises(WorkflowNodeRunFailedError, match="Node Title run failed: boom"): + workflow_entry.WorkflowEntry.run_free_node( + node_data={"type": NodeType.PARAMETER_EXTRACTOR.value, "title": "Node"}, + node_id="node-id", + tenant_id="tenant-id", + user_id="user-id", + user_inputs={"question": "hello"}, + ) + + def test_handle_special_values_serializes_nested_files(self): + file = File( + tenant_id="tenant-id", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://example.com/image.png", + filename="image.png", + extension=".png", + ) + + result = workflow_entry.WorkflowEntry.handle_special_values({"file": file, "nested": {"files": [file]}}) + + assert result == { + "file": file.to_dict(), + "nested": {"files": [file.to_dict()]}, + } + + def test_handle_special_values_returns_none_for_none(self): + assert workflow_entry.WorkflowEntry._handle_special_values(None) is None + + def test_handle_special_values_returns_scalar_as_is(self): + assert workflow_entry.WorkflowEntry._handle_special_values("plain-text") == "plain-text" + + +class TestMappingUserInputsBranches: + def test_rejects_invalid_node_variable_key(self): + class EmptySplitKey(UserString): + def split(self, _sep=None): + return [] + + with pytest.raises(ValueError, match="Invalid node variable broken"): + workflow_entry.WorkflowEntry.mapping_user_inputs_to_variable_pool( + variable_mapping={EmptySplitKey("broken"): ["node", "input"]}, + user_inputs={}, + variable_pool=MagicMock(), + tenant_id="tenant-id", + ) + + def 
test_skips_none_user_input_when_variable_already_exists(self): + variable_pool = MagicMock() + variable_pool.get.return_value = None + + workflow_entry.WorkflowEntry.mapping_user_inputs_to_variable_pool( + variable_mapping={"node.input": ["target", "input"]}, + user_inputs={"node.input": None}, + variable_pool=variable_pool, + tenant_id="tenant-id", + ) + + variable_pool.add.assert_not_called() + + def test_merges_structured_output_values(self): + variable_pool = MagicMock() + variable_pool.get.side_effect = [ + None, + SimpleNamespace(value={"existing": "value"}), + ] + + workflow_entry.WorkflowEntry.mapping_user_inputs_to_variable_pool( + variable_mapping={"node.answer": ["target", "structured_output", "answer"]}, + user_inputs={"node.answer": "new-value"}, + variable_pool=variable_pool, + tenant_id="tenant-id", + ) + + variable_pool.add.assert_called_once_with( + ["target", "structured_output"], + {"existing": "value", "answer": "new-value"}, + ) + + +class TestWorkflowEntryTracing: + def test_traced_node_run_reports_success(self): + layer = MagicMock() + + class FakeNode: + def ensure_execution_id(self): + return None + + def run(self): + yield "event" + + with patch.object(workflow_entry, "ObservabilityLayer", return_value=layer): + events = list(workflow_entry.WorkflowEntry._traced_node_run(FakeNode())) + + assert events == ["event"] + layer.on_graph_start.assert_called_once_with() + layer.on_node_run_start.assert_called_once() + layer.on_node_run_end.assert_called_once_with( + layer.on_node_run_start.call_args.args[0], + None, + ) + + def test_traced_node_run_reports_errors(self): + layer = MagicMock() + + class FakeNode: + def ensure_execution_id(self): + return None + + def run(self): + raise RuntimeError("boom") + yield + + with patch.object(workflow_entry, "ObservabilityLayer", return_value=layer): + with pytest.raises(RuntimeError, match="boom"): + list(workflow_entry.WorkflowEntry._traced_node_run(FakeNode())) + + assert isinstance(layer.on_node_run_end.call_args.args[1], RuntimeError) diff --git a/api/tests/unit_tests/services/workflow/test_workflow_service.py b/api/tests/unit_tests/services/workflow/test_workflow_service.py index 3953248c47..9ee8f88e71 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_service.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_service.py @@ -311,7 +311,9 @@ class TestWorkflowService: mock_workflow.conversation_variables = [] # Mock node config - mock_workflow.get_node_config_by_id.return_value = {"id": "node-1", "data": {"type": "llm"}} + mock_workflow.get_node_config_by_id.return_value = NodeConfigDictAdapter.validate_python( + {"id": "node-1", "data": {"type": NodeType.LLM.value}} + ) mock_workflow.get_enclosing_node_type_and_id.return_value = None # Mock class methods @@ -376,7 +378,9 @@ class TestWorkflowService: mock_workflow.tenant_id = "tenant-1" mock_workflow.environment_variables = [] mock_workflow.conversation_variables = [] - mock_workflow.get_node_config_by_id.return_value = {"id": "node-1", "data": {"type": "llm"}} + mock_workflow.get_node_config_by_id.return_value = NodeConfigDictAdapter.validate_python( + {"id": "node-1", "data": {"type": NodeType.LLM.value}} + ) mock_workflow.get_enclosing_node_type_and_id.return_value = None monkeypatch.setattr(workflow_service_module, "WorkflowDraftVariableService", MagicMock()) From ceaa399351009401721582c4f7e1f7f2bd682cbc Mon Sep 17 00:00:00 2001 From: Coding On Star <447357187@qq.com> Date: Thu, 12 Mar 2026 15:50:07 +0800 Subject: [PATCH 3/3] test: refactor mock 
implementation in markdown component tests (#33350)

Co-authored-by: CodingOnStar
---
 .../base/markdown/__tests__/index.spec.tsx | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)

diff --git a/web/app/components/base/markdown/__tests__/index.spec.tsx b/web/app/components/base/markdown/__tests__/index.spec.tsx
index 4d98fcba2e..5d0261b074 100644
--- a/web/app/components/base/markdown/__tests__/index.spec.tsx
+++ b/web/app/components/base/markdown/__tests__/index.spec.tsx
@@ -7,17 +7,15 @@
 const { mockReactMarkdownWrapper } = vi.hoisted(() => ({
   mockReactMarkdownWrapper: vi.fn(),
 }))
 
-vi.mock('../react-markdown-wrapper', () => ({
-  ReactMarkdownWrapper: () => null,
-}))
-
 vi.mock('next/dynamic', () => ({
-  default: (loader: () => Promise<unknown>) => {
-    void loader()
-    return (props: { latexContent: string }) => {
+  default: () => {
+    const MockStreamdownWrapper = (props: { latexContent: string }) => {
       mockReactMarkdownWrapper(props)
       return <div>{props.latexContent}</div>
     }
+
+    MockStreamdownWrapper.displayName = 'MockStreamdownWrapper'
+    return MockStreamdownWrapper
   },
 }))