mirror of
https://github.com/langgenius/dify.git
synced 2026-03-25 00:07:56 +08:00
Merge remote-tracking branch 'origin/main' into feat/model-plugins-implementing
This commit is contained in:
@ -150,8 +150,9 @@ class TestAdvancedChatAppGeneratorInternals:
|
||||
"_DummyTraceQueueManager",
|
||||
(TraceQueueManager,),
|
||||
{
|
||||
"__init__": lambda self, app_id=None, user_id=None: setattr(self, "app_id", app_id)
|
||||
or setattr(self, "user_id", user_id)
|
||||
"__init__": lambda self, app_id=None, user_id=None: (
|
||||
setattr(self, "app_id", app_id) or setattr(self, "user_id", user_id)
|
||||
)
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.TraceQueueManager", DummyTraceQueueManager)
|
||||
@ -1124,8 +1125,9 @@ class TestAdvancedChatAppGeneratorInternals:
|
||||
"_DummyTraceQueueManager",
|
||||
(TraceQueueManager,),
|
||||
{
|
||||
"__init__": lambda self, app_id=None, user_id=None: setattr(self, "app_id", app_id)
|
||||
or setattr(self, "user_id", user_id)
|
||||
"__init__": lambda self, app_id=None, user_id=None: (
|
||||
setattr(self, "app_id", app_id) or setattr(self, "user_id", user_id)
|
||||
)
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
@ -1202,8 +1204,9 @@ class TestAdvancedChatAppGeneratorInternals:
|
||||
"_DummyTraceQueueManager",
|
||||
(TraceQueueManager,),
|
||||
{
|
||||
"__init__": lambda self, app_id=None, user_id=None: setattr(self, "app_id", app_id)
|
||||
or setattr(self, "user_id", user_id)
|
||||
"__init__": lambda self, app_id=None, user_id=None: (
|
||||
setattr(self, "app_id", app_id) or setattr(self, "user_id", user_id)
|
||||
)
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
|
||||
@ -240,12 +240,12 @@ class TestAdvancedChatGenerateTaskPipeline:
|
||||
def test_iteration_and_loop_handlers(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._workflow_run_id = "run-id"
|
||||
pipeline._workflow_response_converter.workflow_iteration_start_to_stream_response = (
|
||||
lambda **kwargs: "iter_start"
|
||||
pipeline._workflow_response_converter.workflow_iteration_start_to_stream_response = lambda **kwargs: (
|
||||
"iter_start"
|
||||
)
|
||||
pipeline._workflow_response_converter.workflow_iteration_next_to_stream_response = lambda **kwargs: "iter_next"
|
||||
pipeline._workflow_response_converter.workflow_iteration_completed_to_stream_response = (
|
||||
lambda **kwargs: "iter_done"
|
||||
pipeline._workflow_response_converter.workflow_iteration_completed_to_stream_response = lambda **kwargs: (
|
||||
"iter_done"
|
||||
)
|
||||
pipeline._workflow_response_converter.workflow_loop_start_to_stream_response = lambda **kwargs: "loop_start"
|
||||
pipeline._workflow_response_converter.workflow_loop_next_to_stream_response = lambda **kwargs: "loop_next"
|
||||
|
||||
@ -144,8 +144,9 @@ class TestWorkflowAppGeneratorGenerate:
|
||||
"_DummyTraceQueueManager",
|
||||
(TraceQueueManager,),
|
||||
{
|
||||
"__init__": lambda self, app_id=None, user_id=None: setattr(self, "app_id", app_id)
|
||||
or setattr(self, "user_id", user_id)
|
||||
"__init__": lambda self, app_id=None, user_id=None: (
|
||||
setattr(self, "app_id", app_id) or setattr(self, "user_id", user_id)
|
||||
)
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
|
||||
@ -1,82 +1,603 @@
|
||||
from __future__ import annotations
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch, sentinel
|
||||
|
||||
from typing import Any
|
||||
import pytest
|
||||
|
||||
from core.model_manager import ModelInstance
|
||||
from core.workflow.node_factory import DifyNodeFactory
|
||||
from dify_graph.nodes.llm.entities import LLMNodeData
|
||||
from dify_graph.nodes.llm.node import LLMNode
|
||||
from dify_graph.runtime import GraphRuntimeState, VariablePool
|
||||
from dify_graph.system_variable import SystemVariable
|
||||
from tests.workflow_test_utils import build_test_graph_init_params
|
||||
from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, UserFrom
|
||||
from core.workflow import node_factory
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
|
||||
from dify_graph.enums import NodeType, SystemVariableKey
|
||||
from dify_graph.nodes.code.entities import CodeLanguage
|
||||
from dify_graph.variables.segments import StringSegment
|
||||
|
||||
|
||||
def _build_factory(graph_config: dict[str, Any]) -> DifyNodeFactory:
|
||||
graph_init_params = build_test_graph_init_params(
|
||||
workflow_id="workflow",
|
||||
graph_config=graph_config,
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
user_id="user",
|
||||
user_from="account",
|
||||
invoke_from="debugger",
|
||||
call_depth=0,
|
||||
def _assert_typed_node_config(config, *, node_id: str, node_type: NodeType, version: str = "1") -> None:
|
||||
assert config["id"] == node_id
|
||||
assert isinstance(config["data"], BaseNodeData)
|
||||
assert config["data"].type == node_type
|
||||
assert config["data"].version == version
|
||||
|
||||
|
||||
class TestFetchMemory:
|
||||
@pytest.mark.parametrize(
|
||||
("conversation_id", "memory_config"),
|
||||
[
|
||||
(None, object()),
|
||||
("conversation-id", None),
|
||||
],
|
||||
)
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(
|
||||
system_variables=SystemVariable.default(),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
),
|
||||
start_at=0.0,
|
||||
def test_returns_none_when_memory_or_conversation_is_missing(self, conversation_id, memory_config):
|
||||
result = node_factory.fetch_memory(
|
||||
conversation_id=conversation_id,
|
||||
app_id="app-id",
|
||||
node_data_memory=memory_config,
|
||||
model_instance=sentinel.model_instance,
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_returns_none_when_conversation_does_not_exist(self, monkeypatch):
|
||||
class FakeSelect:
|
||||
def where(self, *_args):
|
||||
return self
|
||||
|
||||
class FakeSession:
|
||||
def __init__(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *_args):
|
||||
return False
|
||||
|
||||
def scalar(self, _stmt):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(node_factory, "db", SimpleNamespace(engine=sentinel.engine))
|
||||
monkeypatch.setattr(node_factory, "select", MagicMock(return_value=FakeSelect()))
|
||||
monkeypatch.setattr(node_factory, "Session", FakeSession)
|
||||
|
||||
result = node_factory.fetch_memory(
|
||||
conversation_id="conversation-id",
|
||||
app_id="app-id",
|
||||
node_data_memory=object(),
|
||||
model_instance=sentinel.model_instance,
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_builds_token_buffer_memory_for_existing_conversation(self, monkeypatch):
|
||||
conversation = sentinel.conversation
|
||||
memory = sentinel.memory
|
||||
|
||||
class FakeSelect:
|
||||
def where(self, *_args):
|
||||
return self
|
||||
|
||||
class FakeSession:
|
||||
def __init__(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *_args):
|
||||
return False
|
||||
|
||||
def scalar(self, _stmt):
|
||||
return conversation
|
||||
|
||||
token_buffer_memory = MagicMock(return_value=memory)
|
||||
monkeypatch.setattr(node_factory, "db", SimpleNamespace(engine=sentinel.engine))
|
||||
monkeypatch.setattr(node_factory, "select", MagicMock(return_value=FakeSelect()))
|
||||
monkeypatch.setattr(node_factory, "Session", FakeSession)
|
||||
monkeypatch.setattr(node_factory, "TokenBufferMemory", token_buffer_memory)
|
||||
|
||||
result = node_factory.fetch_memory(
|
||||
conversation_id="conversation-id",
|
||||
app_id="app-id",
|
||||
node_data_memory=object(),
|
||||
model_instance=sentinel.model_instance,
|
||||
)
|
||||
|
||||
assert result is memory
|
||||
token_buffer_memory.assert_called_once_with(
|
||||
conversation=conversation,
|
||||
model_instance=sentinel.model_instance,
|
||||
)
|
||||
|
||||
|
||||
class TestDefaultWorkflowCodeExecutor:
|
||||
def test_execute_delegates_to_code_executor(self, monkeypatch):
|
||||
executor = node_factory.DefaultWorkflowCodeExecutor()
|
||||
execute_workflow_code_template = MagicMock(return_value={"answer": "ok"})
|
||||
monkeypatch.setattr(
|
||||
node_factory.CodeExecutor,
|
||||
"execute_workflow_code_template",
|
||||
execute_workflow_code_template,
|
||||
)
|
||||
|
||||
result = executor.execute(
|
||||
language=CodeLanguage.PYTHON3,
|
||||
code="print('ok')",
|
||||
inputs={"name": "workflow"},
|
||||
)
|
||||
|
||||
assert result == {"answer": "ok"}
|
||||
execute_workflow_code_template.assert_called_once_with(
|
||||
language=CodeLanguage.PYTHON3,
|
||||
code="print('ok')",
|
||||
inputs={"name": "workflow"},
|
||||
)
|
||||
|
||||
def test_is_execution_error_checks_code_execution_error_type(self):
|
||||
executor = node_factory.DefaultWorkflowCodeExecutor()
|
||||
|
||||
assert executor.is_execution_error(node_factory.CodeExecutionError("boom")) is True
|
||||
assert executor.is_execution_error(RuntimeError("boom")) is False
|
||||
|
||||
|
||||
class TestDifyNodeFactoryInit:
|
||||
def test_init_builds_default_dependencies(self):
|
||||
graph_init_params = SimpleNamespace(run_context={"context": "value"})
|
||||
graph_runtime_state = sentinel.graph_runtime_state
|
||||
dify_context = SimpleNamespace(tenant_id="tenant-id")
|
||||
template_renderer = sentinel.template_renderer
|
||||
rag_retrieval = sentinel.rag_retrieval
|
||||
unstructured_api_config = sentinel.unstructured_api_config
|
||||
http_request_config = sentinel.http_request_config
|
||||
credentials_provider = sentinel.credentials_provider
|
||||
model_factory = sentinel.model_factory
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
node_factory.DifyNodeFactory,
|
||||
"_resolve_dify_context",
|
||||
return_value=dify_context,
|
||||
) as resolve_dify_context,
|
||||
patch.object(
|
||||
node_factory,
|
||||
"CodeExecutorJinja2TemplateRenderer",
|
||||
return_value=template_renderer,
|
||||
) as renderer_factory,
|
||||
patch.object(node_factory, "DatasetRetrieval", return_value=rag_retrieval),
|
||||
patch.object(
|
||||
node_factory,
|
||||
"UnstructuredApiConfig",
|
||||
return_value=unstructured_api_config,
|
||||
),
|
||||
patch.object(
|
||||
node_factory,
|
||||
"build_http_request_config",
|
||||
return_value=http_request_config,
|
||||
),
|
||||
patch.object(
|
||||
node_factory,
|
||||
"build_dify_model_access",
|
||||
return_value=(credentials_provider, model_factory),
|
||||
) as build_dify_model_access,
|
||||
):
|
||||
factory = node_factory.DifyNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
resolve_dify_context.assert_called_once_with(graph_init_params.run_context)
|
||||
build_dify_model_access.assert_called_once_with("tenant-id")
|
||||
renderer_factory.assert_called_once()
|
||||
assert renderer_factory.call_args.kwargs["code_executor"] is factory._code_executor
|
||||
assert factory.graph_init_params is graph_init_params
|
||||
assert factory.graph_runtime_state is graph_runtime_state
|
||||
assert factory._dify_context is dify_context
|
||||
assert factory._template_renderer is template_renderer
|
||||
assert factory._rag_retrieval is rag_retrieval
|
||||
assert factory._document_extractor_unstructured_api_config is unstructured_api_config
|
||||
assert factory._http_request_config is http_request_config
|
||||
assert factory._llm_credentials_provider is credentials_provider
|
||||
assert factory._llm_model_factory is model_factory
|
||||
|
||||
|
||||
class TestDifyNodeFactoryResolveContext:
|
||||
def test_requires_reserved_context_key(self):
|
||||
with pytest.raises(ValueError, match=DIFY_RUN_CONTEXT_KEY):
|
||||
node_factory.DifyNodeFactory._resolve_dify_context({})
|
||||
|
||||
def test_returns_existing_dify_context(self):
|
||||
dify_context = DifyRunContext(
|
||||
tenant_id="tenant-id",
|
||||
app_id="app-id",
|
||||
user_id="user-id",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
)
|
||||
|
||||
result = node_factory.DifyNodeFactory._resolve_dify_context({DIFY_RUN_CONTEXT_KEY: dify_context})
|
||||
|
||||
assert result is dify_context
|
||||
|
||||
def test_validates_mapping_context(self):
|
||||
raw_context = {
|
||||
DIFY_RUN_CONTEXT_KEY: {
|
||||
"tenant_id": "tenant-id",
|
||||
"app_id": "app-id",
|
||||
"user_id": "user-id",
|
||||
"user_from": UserFrom.ACCOUNT,
|
||||
"invoke_from": InvokeFrom.DEBUGGER,
|
||||
}
|
||||
}
|
||||
|
||||
result = node_factory.DifyNodeFactory._resolve_dify_context(raw_context)
|
||||
|
||||
assert isinstance(result, DifyRunContext)
|
||||
assert result.tenant_id == "tenant-id"
|
||||
|
||||
|
||||
class TestDifyNodeFactoryCreateNode:
|
||||
@pytest.fixture
|
||||
def factory(self):
|
||||
factory = object.__new__(node_factory.DifyNodeFactory)
|
||||
factory.graph_init_params = sentinel.graph_init_params
|
||||
factory.graph_runtime_state = sentinel.graph_runtime_state
|
||||
factory._dify_context = SimpleNamespace(tenant_id="tenant-id", app_id="app-id")
|
||||
factory._code_executor = sentinel.code_executor
|
||||
factory._code_limits = sentinel.code_limits
|
||||
factory._template_renderer = sentinel.template_renderer
|
||||
factory._template_transform_max_output_length = 2048
|
||||
factory._http_request_http_client = sentinel.http_client
|
||||
factory._http_request_tool_file_manager_factory = sentinel.tool_file_manager_factory
|
||||
factory._http_request_file_manager = sentinel.file_manager
|
||||
factory._rag_retrieval = sentinel.rag_retrieval
|
||||
factory._document_extractor_unstructured_api_config = sentinel.unstructured_api_config
|
||||
factory._http_request_config = sentinel.http_request_config
|
||||
factory._llm_credentials_provider = sentinel.credentials_provider
|
||||
factory._llm_model_factory = sentinel.model_factory
|
||||
return factory
|
||||
|
||||
def test_rejects_unknown_node_type(self, factory):
|
||||
with pytest.raises(ValueError, match="Input should be"):
|
||||
factory.create_node({"id": "node-id", "data": {"type": "missing"}})
|
||||
|
||||
def test_rejects_missing_class_mapping(self, monkeypatch, factory):
|
||||
monkeypatch.setattr(node_factory, "NODE_TYPE_CLASSES_MAPPING", {})
|
||||
|
||||
with pytest.raises(ValueError, match="No class mapping found for node type: start"):
|
||||
factory.create_node({"id": "node-id", "data": {"type": NodeType.START.value}})
|
||||
|
||||
def test_rejects_missing_latest_class(self, monkeypatch, factory):
|
||||
monkeypatch.setattr(
|
||||
node_factory,
|
||||
"NODE_TYPE_CLASSES_MAPPING",
|
||||
{NodeType.START: {node_factory.LATEST_VERSION: None}},
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="No latest version class found for node type: start"):
|
||||
factory.create_node({"id": "node-id", "data": {"type": NodeType.START.value}})
|
||||
|
||||
def test_uses_version_specific_class_when_available(self, monkeypatch, factory):
|
||||
matched_node = sentinel.matched_node
|
||||
latest_node_class = MagicMock(return_value=sentinel.latest_node)
|
||||
matched_node_class = MagicMock(return_value=matched_node)
|
||||
monkeypatch.setattr(
|
||||
node_factory,
|
||||
"NODE_TYPE_CLASSES_MAPPING",
|
||||
{
|
||||
NodeType.START: {
|
||||
node_factory.LATEST_VERSION: latest_node_class,
|
||||
"9": matched_node_class,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
result = factory.create_node({"id": "node-id", "data": {"type": NodeType.START.value, "version": "9"}})
|
||||
|
||||
assert result is matched_node
|
||||
matched_node_class.assert_called_once()
|
||||
kwargs = matched_node_class.call_args.kwargs
|
||||
assert kwargs["id"] == "node-id"
|
||||
_assert_typed_node_config(kwargs["config"], node_id="node-id", node_type=NodeType.START, version="9")
|
||||
assert kwargs["graph_init_params"] is sentinel.graph_init_params
|
||||
assert kwargs["graph_runtime_state"] is sentinel.graph_runtime_state
|
||||
latest_node_class.assert_not_called()
|
||||
|
||||
def test_falls_back_to_latest_class_when_version_specific_mapping_is_missing(self, monkeypatch, factory):
|
||||
latest_node = sentinel.latest_node
|
||||
latest_node_class = MagicMock(return_value=latest_node)
|
||||
monkeypatch.setattr(
|
||||
node_factory,
|
||||
"NODE_TYPE_CLASSES_MAPPING",
|
||||
{NodeType.START: {node_factory.LATEST_VERSION: latest_node_class}},
|
||||
)
|
||||
|
||||
result = factory.create_node({"id": "node-id", "data": {"type": NodeType.START.value, "version": "9"}})
|
||||
|
||||
assert result is latest_node
|
||||
latest_node_class.assert_called_once()
|
||||
kwargs = latest_node_class.call_args.kwargs
|
||||
assert kwargs["id"] == "node-id"
|
||||
_assert_typed_node_config(kwargs["config"], node_id="node-id", node_type=NodeType.START, version="9")
|
||||
assert kwargs["graph_init_params"] is sentinel.graph_init_params
|
||||
assert kwargs["graph_runtime_state"] is sentinel.graph_runtime_state
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("node_type", "constructor_name"),
|
||||
[
|
||||
(NodeType.CODE, "CodeNode"),
|
||||
(NodeType.TEMPLATE_TRANSFORM, "TemplateTransformNode"),
|
||||
(NodeType.HTTP_REQUEST, "HttpRequestNode"),
|
||||
(NodeType.HUMAN_INPUT, "HumanInputNode"),
|
||||
(NodeType.KNOWLEDGE_INDEX, "KnowledgeIndexNode"),
|
||||
(NodeType.DATASOURCE, "DatasourceNode"),
|
||||
(NodeType.KNOWLEDGE_RETRIEVAL, "KnowledgeRetrievalNode"),
|
||||
(NodeType.DOCUMENT_EXTRACTOR, "DocumentExtractorNode"),
|
||||
],
|
||||
)
|
||||
return DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state)
|
||||
def test_creates_specialized_nodes(self, monkeypatch, factory, node_type, constructor_name):
|
||||
created_node = object()
|
||||
constructor = MagicMock(name=constructor_name, return_value=created_node)
|
||||
monkeypatch.setattr(
|
||||
node_factory,
|
||||
"NODE_TYPE_CLASSES_MAPPING",
|
||||
{node_type: {node_factory.LATEST_VERSION: constructor}},
|
||||
)
|
||||
|
||||
if constructor_name == "HumanInputNode":
|
||||
form_repository = sentinel.form_repository
|
||||
form_repository_impl = MagicMock(return_value=form_repository)
|
||||
monkeypatch.setattr(
|
||||
node_factory,
|
||||
"HumanInputFormRepositoryImpl",
|
||||
form_repository_impl,
|
||||
)
|
||||
elif constructor_name == "KnowledgeIndexNode":
|
||||
index_processor = sentinel.index_processor
|
||||
summary_index = sentinel.summary_index
|
||||
monkeypatch.setattr(node_factory, "IndexProcessor", MagicMock(return_value=index_processor))
|
||||
monkeypatch.setattr(node_factory, "SummaryIndex", MagicMock(return_value=summary_index))
|
||||
|
||||
node_config = {"id": "node-id", "data": {"type": node_type.value}}
|
||||
result = factory.create_node(node_config)
|
||||
|
||||
assert result is created_node
|
||||
kwargs = constructor.call_args.kwargs
|
||||
assert kwargs["id"] == "node-id"
|
||||
_assert_typed_node_config(kwargs["config"], node_id="node-id", node_type=node_type)
|
||||
assert kwargs["graph_init_params"] is sentinel.graph_init_params
|
||||
assert kwargs["graph_runtime_state"] is sentinel.graph_runtime_state
|
||||
|
||||
if constructor_name == "CodeNode":
|
||||
assert kwargs["code_executor"] is sentinel.code_executor
|
||||
assert kwargs["code_limits"] is sentinel.code_limits
|
||||
elif constructor_name == "TemplateTransformNode":
|
||||
assert kwargs["template_renderer"] is sentinel.template_renderer
|
||||
assert kwargs["max_output_length"] == 2048
|
||||
elif constructor_name == "HttpRequestNode":
|
||||
assert kwargs["http_request_config"] is sentinel.http_request_config
|
||||
assert kwargs["http_client"] is sentinel.http_client
|
||||
assert kwargs["tool_file_manager_factory"] is sentinel.tool_file_manager_factory
|
||||
assert kwargs["file_manager"] is sentinel.file_manager
|
||||
elif constructor_name == "HumanInputNode":
|
||||
assert kwargs["form_repository"] is form_repository
|
||||
form_repository_impl.assert_called_once_with(tenant_id="tenant-id")
|
||||
elif constructor_name == "KnowledgeIndexNode":
|
||||
assert kwargs["index_processor"] is index_processor
|
||||
assert kwargs["summary_index_service"] is summary_index
|
||||
elif constructor_name == "DatasourceNode":
|
||||
assert kwargs["datasource_manager"] is node_factory.DatasourceManager
|
||||
elif constructor_name == "KnowledgeRetrievalNode":
|
||||
assert kwargs["rag_retrieval"] is sentinel.rag_retrieval
|
||||
elif constructor_name == "DocumentExtractorNode":
|
||||
assert kwargs["unstructured_api_config"] is sentinel.unstructured_api_config
|
||||
assert kwargs["http_client"] is sentinel.http_client
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("node_type", "constructor_name", "expected_extra_kwargs"),
|
||||
[
|
||||
(NodeType.LLM, "LLMNode", {"http_client": sentinel.http_client}),
|
||||
(NodeType.QUESTION_CLASSIFIER, "QuestionClassifierNode", {"http_client": sentinel.http_client}),
|
||||
(NodeType.PARAMETER_EXTRACTOR, "ParameterExtractorNode", {}),
|
||||
],
|
||||
)
|
||||
def test_creates_model_backed_nodes(
|
||||
self,
|
||||
monkeypatch,
|
||||
factory,
|
||||
node_type,
|
||||
constructor_name,
|
||||
expected_extra_kwargs,
|
||||
):
|
||||
created_node = object()
|
||||
constructor = MagicMock(name=constructor_name, return_value=created_node)
|
||||
monkeypatch.setattr(
|
||||
node_factory,
|
||||
"NODE_TYPE_CLASSES_MAPPING",
|
||||
{node_type: {node_factory.LATEST_VERSION: constructor}},
|
||||
)
|
||||
llm_init_kwargs = {
|
||||
"credentials_provider": sentinel.credentials_provider,
|
||||
"model_factory": sentinel.model_factory,
|
||||
"model_instance": sentinel.model_instance,
|
||||
"memory": sentinel.memory,
|
||||
**expected_extra_kwargs,
|
||||
}
|
||||
build_llm_init_kwargs = MagicMock(return_value=llm_init_kwargs)
|
||||
factory._build_llm_compatible_node_init_kwargs = build_llm_init_kwargs
|
||||
|
||||
node_config = {"id": "node-id", "data": {"type": node_type.value}}
|
||||
result = factory.create_node(node_config)
|
||||
|
||||
assert result is created_node
|
||||
build_llm_init_kwargs.assert_called_once()
|
||||
helper_kwargs = build_llm_init_kwargs.call_args.kwargs
|
||||
assert helper_kwargs["node_class"] is constructor
|
||||
assert isinstance(helper_kwargs["node_data"], BaseNodeData)
|
||||
assert helper_kwargs["node_data"].type == node_type
|
||||
assert helper_kwargs["include_http_client"] is (node_type != NodeType.PARAMETER_EXTRACTOR)
|
||||
|
||||
constructor_kwargs = constructor.call_args.kwargs
|
||||
assert constructor_kwargs["id"] == "node-id"
|
||||
_assert_typed_node_config(constructor_kwargs["config"], node_id="node-id", node_type=node_type)
|
||||
assert constructor_kwargs["graph_init_params"] is sentinel.graph_init_params
|
||||
assert constructor_kwargs["graph_runtime_state"] is sentinel.graph_runtime_state
|
||||
assert constructor_kwargs["credentials_provider"] is sentinel.credentials_provider
|
||||
assert constructor_kwargs["model_factory"] is sentinel.model_factory
|
||||
assert constructor_kwargs["model_instance"] is sentinel.model_instance
|
||||
assert constructor_kwargs["memory"] is sentinel.memory
|
||||
for key, value in expected_extra_kwargs.items():
|
||||
assert constructor_kwargs[key] is value
|
||||
|
||||
|
||||
def test_create_node_uses_declared_node_data_type_for_llm_validation(monkeypatch):
|
||||
class _FactoryLLMNodeData(LLMNodeData):
|
||||
pass
|
||||
class TestDifyNodeFactoryModelInstance:
|
||||
@pytest.fixture
|
||||
def factory(self):
|
||||
factory = object.__new__(node_factory.DifyNodeFactory)
|
||||
factory._llm_credentials_provider = MagicMock()
|
||||
factory._llm_model_factory = MagicMock()
|
||||
return factory
|
||||
|
||||
llm_node_config = {
|
||||
"id": "llm-node",
|
||||
"data": {
|
||||
"type": "llm",
|
||||
"title": "LLM",
|
||||
"model": {
|
||||
"provider": "openai",
|
||||
"name": "gpt-4o-mini",
|
||||
"mode": "chat",
|
||||
"completion_params": {},
|
||||
},
|
||||
"prompt_template": [],
|
||||
"context": {
|
||||
"enabled": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
graph_config = {"nodes": [llm_node_config], "edges": []}
|
||||
factory = _build_factory(graph_config)
|
||||
captured: dict[str, object] = {}
|
||||
@pytest.fixture
|
||||
def llm_model_setup(self, factory):
|
||||
def _configure(
|
||||
*,
|
||||
completion_params=None,
|
||||
has_provider_model=True,
|
||||
model_schema=sentinel.model_schema,
|
||||
):
|
||||
credentials = {"api_key": "secret"}
|
||||
node_data_model = SimpleNamespace(
|
||||
provider="provider",
|
||||
name="model",
|
||||
mode="chat",
|
||||
completion_params=completion_params or {},
|
||||
)
|
||||
node_data = SimpleNamespace(model=node_data_model)
|
||||
provider_model = MagicMock() if has_provider_model else None
|
||||
provider_model_bundle = SimpleNamespace(
|
||||
configuration=SimpleNamespace(get_provider_model=MagicMock(return_value=provider_model))
|
||||
)
|
||||
model_type_instance = MagicMock()
|
||||
model_type_instance.get_model_schema.return_value = model_schema
|
||||
model_instance = SimpleNamespace(
|
||||
provider_model_bundle=provider_model_bundle,
|
||||
model_type_instance=model_type_instance,
|
||||
provider=None,
|
||||
model_name=None,
|
||||
credentials=None,
|
||||
parameters=None,
|
||||
stop=None,
|
||||
)
|
||||
factory._llm_credentials_provider.fetch.return_value = credentials
|
||||
factory._llm_model_factory.init_model_instance.return_value = model_instance
|
||||
return SimpleNamespace(
|
||||
node_data=node_data,
|
||||
credentials=credentials,
|
||||
provider_model=provider_model,
|
||||
model_type_instance=model_type_instance,
|
||||
model_instance=model_instance,
|
||||
)
|
||||
|
||||
monkeypatch.setattr(LLMNode, "_node_data_type", _FactoryLLMNodeData)
|
||||
return _configure
|
||||
|
||||
def _capture_model_instance(self: DifyNodeFactory, node_data: object) -> ModelInstance:
|
||||
captured["node_data"] = node_data
|
||||
return object() # type: ignore[return-value]
|
||||
def test_requires_llm_mode(self, factory):
|
||||
node_data = SimpleNamespace(
|
||||
model=SimpleNamespace(
|
||||
provider="provider",
|
||||
name="model",
|
||||
mode="",
|
||||
completion_params={},
|
||||
)
|
||||
)
|
||||
|
||||
def _capture_memory(
|
||||
self: DifyNodeFactory,
|
||||
*,
|
||||
node_data: object,
|
||||
model_instance: ModelInstance,
|
||||
) -> None:
|
||||
captured["memory_node_data"] = node_data
|
||||
with pytest.raises(node_factory.LLMModeRequiredError, match="LLM mode is required"):
|
||||
factory._build_model_instance_for_llm_node(node_data)
|
||||
|
||||
monkeypatch.setattr(DifyNodeFactory, "_build_model_instance_for_llm_node", _capture_model_instance)
|
||||
monkeypatch.setattr(DifyNodeFactory, "_build_memory_for_llm_node", _capture_memory)
|
||||
def test_raises_when_provider_model_is_missing(self, factory, llm_model_setup):
|
||||
setup = llm_model_setup(has_provider_model=False)
|
||||
|
||||
node = factory.create_node(llm_node_config)
|
||||
with pytest.raises(node_factory.ModelNotExistError, match="Model model not exist"):
|
||||
factory._build_model_instance_for_llm_node(setup.node_data)
|
||||
|
||||
assert isinstance(captured["node_data"], _FactoryLLMNodeData)
|
||||
assert isinstance(captured["memory_node_data"], _FactoryLLMNodeData)
|
||||
assert isinstance(node.node_data, _FactoryLLMNodeData)
|
||||
def test_raises_when_model_schema_is_missing(self, factory, llm_model_setup):
|
||||
setup = llm_model_setup(model_schema=None)
|
||||
|
||||
with pytest.raises(node_factory.ModelNotExistError, match="Model model not exist"):
|
||||
factory._build_model_instance_for_llm_node(setup.node_data)
|
||||
|
||||
setup.provider_model.raise_for_status.assert_called_once()
|
||||
|
||||
def test_builds_model_instance_and_normalizes_stop_tokens(self, factory, llm_model_setup):
|
||||
setup = llm_model_setup(
|
||||
completion_params={"temperature": 0.3, "stop": "not-a-list"},
|
||||
model_schema={"schema": "value"},
|
||||
)
|
||||
|
||||
result = factory._build_model_instance_for_llm_node(setup.node_data)
|
||||
|
||||
assert result is setup.model_instance
|
||||
assert result.provider == "provider"
|
||||
assert result.model_name == "model"
|
||||
assert result.credentials == setup.credentials
|
||||
assert result.parameters == {"temperature": 0.3}
|
||||
assert result.stop == ()
|
||||
assert result.model_type_instance is setup.model_type_instance
|
||||
setup.provider_model.raise_for_status.assert_called_once()
|
||||
|
||||
|
||||
class TestDifyNodeFactoryMemory:
|
||||
@pytest.fixture
|
||||
def factory(self):
|
||||
factory = object.__new__(node_factory.DifyNodeFactory)
|
||||
factory._dify_context = SimpleNamespace(app_id="app-id")
|
||||
factory.graph_runtime_state = SimpleNamespace(variable_pool=MagicMock())
|
||||
return factory
|
||||
|
||||
def test_returns_none_when_memory_is_not_configured(self, factory):
|
||||
result = factory._build_memory_for_llm_node(
|
||||
node_data=SimpleNamespace(memory=None),
|
||||
model_instance=sentinel.model_instance,
|
||||
)
|
||||
|
||||
assert result is None
|
||||
factory.graph_runtime_state.variable_pool.get.assert_not_called()
|
||||
|
||||
def test_uses_string_segment_conversation_id(self, monkeypatch, factory):
|
||||
memory_config = sentinel.memory_config
|
||||
factory.graph_runtime_state.variable_pool.get.return_value = StringSegment(value="conversation-id")
|
||||
fetch_memory = MagicMock(return_value=sentinel.memory)
|
||||
monkeypatch.setattr(node_factory, "fetch_memory", fetch_memory)
|
||||
|
||||
result = factory._build_memory_for_llm_node(
|
||||
node_data=SimpleNamespace(memory=memory_config),
|
||||
model_instance=sentinel.model_instance,
|
||||
)
|
||||
|
||||
assert result is sentinel.memory
|
||||
factory.graph_runtime_state.variable_pool.get.assert_called_once_with(
|
||||
["sys", SystemVariableKey.CONVERSATION_ID]
|
||||
)
|
||||
fetch_memory.assert_called_once_with(
|
||||
conversation_id="conversation-id",
|
||||
app_id="app-id",
|
||||
node_data_memory=memory_config,
|
||||
model_instance=sentinel.model_instance,
|
||||
)
|
||||
|
||||
def test_ignores_non_string_segment_conversation_ids(self, monkeypatch, factory):
|
||||
memory_config = sentinel.memory_config
|
||||
factory.graph_runtime_state.variable_pool.get.return_value = sentinel.segment
|
||||
fetch_memory = MagicMock(return_value=sentinel.memory)
|
||||
monkeypatch.setattr(node_factory, "fetch_memory", fetch_memory)
|
||||
|
||||
result = factory._build_memory_for_llm_node(
|
||||
node_data=SimpleNamespace(memory=memory_config),
|
||||
model_instance=sentinel.model_instance,
|
||||
)
|
||||
|
||||
assert result is sentinel.memory
|
||||
fetch_memory.assert_called_once_with(
|
||||
conversation_id=None,
|
||||
app_id="app-id",
|
||||
node_data_memory=memory_config,
|
||||
model_instance=sentinel.model_instance,
|
||||
)
|
||||
|
||||
@ -0,0 +1,656 @@
|
||||
from collections import UserString
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch, sentinel
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
|
||||
from core.workflow import workflow_entry
|
||||
from dify_graph.entities.graph_config import NodeConfigDictAdapter
|
||||
from dify_graph.errors import WorkflowNodeRunFailedError
|
||||
from dify_graph.file.enums import FileTransferMethod, FileType
|
||||
from dify_graph.file.models import File
|
||||
from dify_graph.graph_events import GraphRunFailedEvent
|
||||
from dify_graph.nodes import NodeType
|
||||
from dify_graph.runtime import ChildGraphNotFoundError
|
||||
|
||||
|
||||
def _build_typed_node_config(node_type: NodeType):
|
||||
return NodeConfigDictAdapter.validate_python({"id": "node-id", "data": {"type": node_type}})
|
||||
|
||||
|
||||
class TestWorkflowChildEngineBuilder:
    """Unit tests for workflow_entry._WorkflowChildEngineBuilder."""

    @pytest.mark.parametrize(
        ("graph_config", "node_id", "expected"),
        [
            ({"nodes": [{"id": "root"}]}, "root", True),
            ({"nodes": [{"id": "root"}]}, "other", False),
            # Malformed configs: the helper returns None rather than raising.
            ({"nodes": "invalid"}, "root", None),
            ({"nodes": ["invalid"]}, "root", None),
        ],
    )
    def test_has_node_id(self, graph_config, node_id, expected):
        """_has_node_id finds a node by id; malformed configs yield None."""
        result = workflow_entry._WorkflowChildEngineBuilder._has_node_id(graph_config, node_id)

        assert result is expected

    def test_build_child_engine_raises_when_root_node_is_missing(self):
        """A root_node_id absent from the graph raises ChildGraphNotFoundError."""
        builder = workflow_entry._WorkflowChildEngineBuilder()

        with patch.object(workflow_entry, "DifyNodeFactory", return_value=sentinel.factory):
            with pytest.raises(ChildGraphNotFoundError, match="child graph root node 'missing' not found"):
                builder.build_child_engine(
                    workflow_id="workflow-id",
                    graph_init_params=sentinel.graph_init_params,
                    graph_runtime_state=sentinel.graph_runtime_state,
                    graph_config={"nodes": []},
                    root_node_id="missing",
                )

    def test_build_child_engine_constructs_graph_engine_and_layers(self):
        """build_child_engine wires factory -> graph -> engine and stacks the
        quota layer before any caller-supplied layers."""
        builder = workflow_entry._WorkflowChildEngineBuilder()
        child_graph = sentinel.child_graph
        child_engine = MagicMock()
        quota_layer = sentinel.quota_layer
        additional_layers = [sentinel.layer_one, sentinel.layer_two]

        with (
            patch.object(workflow_entry, "DifyNodeFactory", return_value=sentinel.factory) as dify_node_factory,
            patch.object(workflow_entry.Graph, "init", return_value=child_graph) as graph_init,
            patch.object(workflow_entry, "GraphEngine", return_value=child_engine) as graph_engine_cls,
            patch.object(workflow_entry, "GraphEngineConfig", return_value=sentinel.graph_engine_config),
            patch.object(workflow_entry, "InMemoryChannel", return_value=sentinel.command_channel),
            patch.object(workflow_entry, "LLMQuotaLayer", return_value=quota_layer),
        ):
            result = builder.build_child_engine(
                workflow_id="workflow-id",
                graph_init_params=sentinel.graph_init_params,
                graph_runtime_state=sentinel.graph_runtime_state,
                graph_config={"nodes": [{"id": "root"}]},
                root_node_id="root",
                layers=additional_layers,
            )

        assert result is child_engine
        dify_node_factory.assert_called_once_with(
            graph_init_params=sentinel.graph_init_params,
            graph_runtime_state=sentinel.graph_runtime_state,
        )
        graph_init.assert_called_once_with(
            graph_config={"nodes": [{"id": "root"}]},
            node_factory=sentinel.factory,
            root_node_id="root",
        )
        # The builder passes itself back so nested child graphs can be built.
        graph_engine_cls.assert_called_once_with(
            workflow_id="workflow-id",
            graph=child_graph,
            graph_runtime_state=sentinel.graph_runtime_state,
            command_channel=sentinel.command_channel,
            config=sentinel.graph_engine_config,
            child_engine_builder=builder,
        )
        # Layer order matters: quota first, then caller layers in given order.
        assert child_engine.layer.call_args_list == [
            ((quota_layer,), {}),
            ((sentinel.layer_one,), {}),
            ((sentinel.layer_two,), {}),
        ]
|
||||
|
||||
|
||||
class TestWorkflowEntryInit:
    """Constructor-level behavior of WorkflowEntry."""

    def test_rejects_call_depth_above_limit(self):
        """A call_depth beyond WORKFLOW_CALL_MAX_DEPTH raises ValueError."""
        call_depth = workflow_entry.dify_config.WORKFLOW_CALL_MAX_DEPTH + 1

        with pytest.raises(ValueError, match="Max workflow call depth"):
            workflow_entry.WorkflowEntry(
                tenant_id="tenant-id",
                app_id="app-id",
                workflow_id="workflow-id",
                graph_config={"nodes": [], "edges": []},
                graph=sentinel.graph,
                user_id="user-id",
                user_from=UserFrom.ACCOUNT,
                invoke_from=InvokeFrom.DEBUGGER,
                call_depth=call_depth,
                variable_pool=sentinel.variable_pool,
                graph_runtime_state=sentinel.graph_runtime_state,
            )

    def test_applies_debug_and_observability_layers(self):
        """With DEBUG on and instrumentation enabled, the engine gets the
        debug, execution-limits, quota, and observability layers in order."""
        graph_engine = MagicMock()
        debug_layer = sentinel.debug_layer
        execution_limits_layer = sentinel.execution_limits_layer
        llm_quota_layer = sentinel.llm_quota_layer
        observability_layer = sentinel.observability_layer

        with (
            patch.object(workflow_entry.dify_config, "DEBUG", True),
            patch.object(workflow_entry.dify_config, "ENABLE_OTEL", False),
            patch.object(workflow_entry, "is_instrument_flag_enabled", return_value=True),
            patch.object(workflow_entry, "GraphEngine", return_value=graph_engine) as graph_engine_cls,
            patch.object(workflow_entry, "GraphEngineConfig", return_value=sentinel.graph_engine_config),
            patch.object(workflow_entry, "InMemoryChannel", return_value=sentinel.command_channel),
            patch.object(workflow_entry, "DebugLoggingLayer", return_value=debug_layer) as debug_logging_layer,
            patch.object(
                workflow_entry,
                "ExecutionLimitsLayer",
                return_value=execution_limits_layer,
            ) as execution_limits_layer_cls,
            patch.object(workflow_entry, "LLMQuotaLayer", return_value=llm_quota_layer),
            patch.object(workflow_entry, "ObservabilityLayer", return_value=observability_layer),
        ):
            entry = workflow_entry.WorkflowEntry(
                tenant_id="tenant-id",
                app_id="app-id",
                workflow_id="workflow-id-123456",
                graph_config={"nodes": [], "edges": []},
                graph=sentinel.graph,
                user_id="user-id",
                user_from=UserFrom.ACCOUNT,
                invoke_from=InvokeFrom.DEBUGGER,
                call_depth=0,
                variable_pool=sentinel.variable_pool,
                graph_runtime_state=sentinel.graph_runtime_state,
                # None triggers creation of a default InMemoryChannel.
                command_channel=None,
            )

        assert entry.command_channel is sentinel.command_channel
        graph_engine_cls.assert_called_once_with(
            workflow_id="workflow-id-123456",
            graph=sentinel.graph,
            graph_runtime_state=sentinel.graph_runtime_state,
            command_channel=sentinel.command_channel,
            config=sentinel.graph_engine_config,
            child_engine_builder=entry._child_engine_builder,
        )
        # Debug logger name appears to be derived from the workflow id prefix.
        debug_logging_layer.assert_called_once_with(
            level="DEBUG",
            include_inputs=True,
            include_outputs=True,
            include_process_data=False,
            logger_name="GraphEngine.Debug.workflow",
        )
        execution_limits_layer_cls.assert_called_once_with(
            max_steps=workflow_entry.dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
            max_time=workflow_entry.dify_config.WORKFLOW_MAX_EXECUTION_TIME,
        )
        # Layer application order is part of the contract under test.
        assert graph_engine.layer.call_args_list == [
            ((debug_layer,), {}),
            ((execution_limits_layer,), {}),
            ((llm_quota_layer,), {}),
            ((observability_layer,), {}),
        ]
|
||||
|
||||
|
||||
class TestWorkflowEntryRun:
    """Behavior of WorkflowEntry.run when the underlying graph engine fails."""

    @staticmethod
    def _entry_with_engine_error(error):
        # Bypass __init__ so no real engine wiring is required, then inject a
        # mock engine whose run() raises the supplied error.
        entry = object.__new__(workflow_entry.WorkflowEntry)
        entry.graph_engine = MagicMock()
        entry.graph_engine.run.side_effect = error
        return entry

    def test_run_swallows_generate_task_stopped_errors(self):
        """A GenerateTaskStoppedError terminates the stream with no events."""
        entry = self._entry_with_engine_error(GenerateTaskStoppedError())

        assert list(entry.run()) == []

    def test_run_emits_failed_event_for_unexpected_errors(self):
        """Any other exception becomes exactly one GraphRunFailedEvent."""
        entry = self._entry_with_engine_error(RuntimeError("boom"))

        emitted = list(entry.run())

        assert len(emitted) == 1
        assert isinstance(emitted[0], GraphRunFailedEvent)
        assert emitted[0].error == "boom"
|
||||
|
||||
|
||||
class TestWorkflowEntrySingleStepRun:
    """Tests for WorkflowEntry.single_step_run's variable-mapping branches."""

    def test_uses_empty_mapping_when_selector_extraction_is_not_implemented(self):
        """If the node class does not implement selector extraction, an empty
        variable mapping is used for loading and user-input mapping."""

        class FakeNode:
            id = "node-id"
            title = "Node Title"
            node_type = "fake"

            @staticmethod
            def version():
                return "1"

            @staticmethod
            def extract_variable_selector_to_variable_mapping(**_kwargs):
                # Simulates a node type with no selector-extraction support.
                raise NotImplementedError

        with (
            patch.object(workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params),
            patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state),
            patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}),
            patch.object(workflow_entry.time, "perf_counter", return_value=123.0),
            patch.object(workflow_entry, "DifyNodeFactory") as dify_node_factory,
            patch.object(workflow_entry, "load_into_variable_pool") as load_into_variable_pool,
            patch.object(
                workflow_entry.WorkflowEntry,
                "mapping_user_inputs_to_variable_pool",
            ) as mapping_user_inputs_to_variable_pool,
            patch.object(
                workflow_entry.WorkflowEntry,
                "_traced_node_run",
                return_value=iter(["event"]),
            ),
        ):
            dify_node_factory.return_value.create_node.return_value = FakeNode()
            workflow = SimpleNamespace(
                tenant_id="tenant-id",
                app_id="app-id",
                id="workflow-id",
                graph_dict={"nodes": [], "edges": []},
                get_node_config_by_id=lambda _node_id: _build_typed_node_config(NodeType.START),
            )

            node, generator = workflow_entry.WorkflowEntry.single_step_run(
                workflow=workflow,
                node_id="node-id",
                user_id="user-id",
                user_inputs={"question": "hello"},
                variable_pool=sentinel.variable_pool,
            )

        assert node.id == "node-id"
        assert list(generator) == ["event"]
        # Both downstream calls must receive the empty fallback mapping.
        load_into_variable_pool.assert_called_once_with(
            variable_loader=workflow_entry.DUMMY_VARIABLE_LOADER,
            variable_pool=sentinel.variable_pool,
            variable_mapping={},
            user_inputs={"question": "hello"},
        )
        mapping_user_inputs_to_variable_pool.assert_called_once_with(
            variable_mapping={},
            user_inputs={"question": "hello"},
            variable_pool=sentinel.variable_pool,
            tenant_id="tenant-id",
        )

    def test_skips_user_input_mapping_for_datasource_nodes(self):
        """Datasource nodes load variables but never map user inputs."""

        class FakeDatasourceNode:
            id = "node-id"
            node_type = "datasource"

            @staticmethod
            def version():
                return "1"

            @staticmethod
            def extract_variable_selector_to_variable_mapping(**_kwargs):
                return {"question": ["node", "question"]}

        with (
            patch.object(workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params),
            patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state),
            patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}),
            patch.object(workflow_entry.time, "perf_counter", return_value=123.0),
            patch.object(workflow_entry, "DifyNodeFactory") as dify_node_factory,
            patch.object(workflow_entry, "load_into_variable_pool") as load_into_variable_pool,
            patch.object(
                workflow_entry.WorkflowEntry,
                "mapping_user_inputs_to_variable_pool",
            ) as mapping_user_inputs_to_variable_pool,
            patch.object(
                workflow_entry.WorkflowEntry,
                "_traced_node_run",
                return_value=iter(["event"]),
            ),
        ):
            dify_node_factory.return_value.create_node.return_value = FakeDatasourceNode()
            workflow = SimpleNamespace(
                tenant_id="tenant-id",
                app_id="app-id",
                id="workflow-id",
                graph_dict={"nodes": [], "edges": []},
                get_node_config_by_id=lambda _node_id: _build_typed_node_config(NodeType.DATASOURCE),
            )

            node, generator = workflow_entry.WorkflowEntry.single_step_run(
                workflow=workflow,
                node_id="node-id",
                user_id="user-id",
                user_inputs={"question": "hello"},
                variable_pool=sentinel.variable_pool,
            )

        assert node.id == "node-id"
        assert list(generator) == ["event"]
        load_into_variable_pool.assert_called_once()
        # The datasource branch must bypass user-input mapping entirely.
        mapping_user_inputs_to_variable_pool.assert_not_called()

    def test_wraps_traced_node_run_failures(self):
        """Failures inside _traced_node_run surface as WorkflowNodeRunFailedError."""

        class FakeNode:
            id = "node-id"
            title = "Node Title"
            node_type = "fake"

            @staticmethod
            def extract_variable_selector_to_variable_mapping(**_kwargs):
                return {}

            @staticmethod
            def version():
                return "1"

        with (
            patch.object(workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params),
            patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state),
            patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}),
            patch.object(workflow_entry.time, "perf_counter", return_value=123.0),
            patch.object(workflow_entry, "DifyNodeFactory") as dify_node_factory,
            patch.object(workflow_entry, "load_into_variable_pool"),
            patch.object(workflow_entry.WorkflowEntry, "mapping_user_inputs_to_variable_pool"),
            patch.object(
                workflow_entry.WorkflowEntry,
                "_traced_node_run",
                side_effect=RuntimeError("boom"),
            ),
        ):
            dify_node_factory.return_value.create_node.return_value = FakeNode()
            workflow = SimpleNamespace(
                tenant_id="tenant-id",
                app_id="app-id",
                id="workflow-id",
                graph_dict={"nodes": [], "edges": []},
                get_node_config_by_id=lambda _node_id: _build_typed_node_config(NodeType.START),
            )

            with pytest.raises(WorkflowNodeRunFailedError):
                workflow_entry.WorkflowEntry.single_step_run(
                    workflow=workflow,
                    node_id="node-id",
                    user_id="user-id",
                    user_inputs={},
                    variable_pool=sentinel.variable_pool,
                )
|
||||
|
||||
|
||||
class TestWorkflowEntryHelpers:
    """Tests for WorkflowEntry helper classmethods: graph synthesis,
    run_free_node validation paths, and special-value serialization."""

    def test_create_single_node_graph_builds_start_edge(self):
        """_create_single_node_graph yields a start node wired to the target."""
        graph = workflow_entry.WorkflowEntry._create_single_node_graph(
            node_id="target-node",
            node_data={"type": NodeType.PARAMETER_EXTRACTOR},
            node_width=320,
            node_height=180,
        )

        assert graph["nodes"][0]["id"] == "start"
        assert graph["nodes"][1]["id"] == "target-node"
        assert graph["nodes"][1]["width"] == 320
        assert graph["nodes"][1]["height"] == 180
        assert graph["edges"] == [
            {
                "source": "start",
                "target": "target-node",
                "sourceHandle": "source",
                "targetHandle": "target",
            }
        ]

    def test_run_free_node_rejects_unsupported_types(self):
        """Node types outside the free-run whitelist raise ValueError."""
        with pytest.raises(ValueError, match="Node type start not supported"):
            workflow_entry.WorkflowEntry.run_free_node(
                node_data={"type": NodeType.START.value},
                node_id="node-id",
                tenant_id="tenant-id",
                user_id="user-id",
                user_inputs={},
            )

    def test_run_free_node_rejects_missing_node_class(self, monkeypatch):
        """A None entry in NODE_TYPE_CLASSES_MAPPING raises ValueError."""
        monkeypatch.setattr(
            workflow_entry,
            "NODE_TYPE_CLASSES_MAPPING",
            {NodeType.PARAMETER_EXTRACTOR: {"1": None}},
        )

        with pytest.raises(ValueError, match="Node class not found for node type parameter-extractor"):
            workflow_entry.WorkflowEntry.run_free_node(
                node_data={"type": NodeType.PARAMETER_EXTRACTOR.value},
                node_id="node-id",
                tenant_id="tenant-id",
                user_id="user-id",
                user_inputs={},
            )

    def test_run_free_node_uses_empty_mapping_when_selector_extraction_is_not_implemented(self, monkeypatch):
        """run_free_node falls back to an empty mapping when the node class's
        selector extraction raises NotImplementedError, and wires the pool,
        context, init params, and factory with the expected arguments."""

        class FakeNodeClass:
            @staticmethod
            def extract_variable_selector_to_variable_mapping(**_kwargs):
                raise NotImplementedError

        class FakeNode:
            id = "node-id"
            title = "Node Title"
            node_type = "parameter-extractor"

            @staticmethod
            def version():
                return "1"

        dify_node_factory = MagicMock()
        dify_node_factory.create_node.return_value = FakeNode()
        monkeypatch.setattr(
            workflow_entry,
            "NODE_TYPE_CLASSES_MAPPING",
            {NodeType.PARAMETER_EXTRACTOR: {"1": FakeNodeClass}},
        )

        with (
            patch.object(workflow_entry.SystemVariable, "default", return_value=sentinel.system_variables),
            patch.object(workflow_entry, "VariablePool", return_value=sentinel.variable_pool) as variable_pool_cls,
            patch.object(
                workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params
            ) as graph_init_params,
            patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state),
            patch.object(
                workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}
            ) as build_dify_run_context,
            patch.object(workflow_entry.time, "perf_counter", return_value=123.0),
            patch.object(workflow_entry, "DifyNodeFactory", return_value=dify_node_factory) as dify_node_factory_cls,
            patch.object(
                workflow_entry.WorkflowEntry,
                "mapping_user_inputs_to_variable_pool",
            ) as mapping_user_inputs_to_variable_pool,
            patch.object(
                workflow_entry.WorkflowEntry,
                "_traced_node_run",
                return_value=iter(["event"]),
            ),
        ):
            node, generator = workflow_entry.WorkflowEntry.run_free_node(
                node_data={"type": NodeType.PARAMETER_EXTRACTOR.value, "title": "Node"},
                node_id="node-id",
                tenant_id="tenant-id",
                user_id="user-id",
                user_inputs={"question": "hello"},
            )

        assert node.id == "node-id"
        assert list(generator) == ["event"]
        # The pool starts empty; user inputs are applied via the mapping call.
        variable_pool_cls.assert_called_once_with(
            system_variables=sentinel.system_variables,
            user_inputs={},
            environment_variables=[],
        )
        build_dify_run_context.assert_called_once_with(
            tenant_id="tenant-id",
            app_id="",
            user_id="user-id",
            user_from=UserFrom.ACCOUNT,
            invoke_from=InvokeFrom.DEBUGGER,
        )
        graph_init_params.assert_called_once_with(
            workflow_id="",
            graph_config=workflow_entry.WorkflowEntry._create_single_node_graph(
                "node-id", {"type": NodeType.PARAMETER_EXTRACTOR.value, "title": "Node"}
            ),
            run_context={"_dify": "context"},
            call_depth=0,
        )
        dify_node_factory_cls.assert_called_once_with(
            graph_init_params=sentinel.graph_init_params,
            graph_runtime_state=sentinel.graph_runtime_state,
        )
        mapping_user_inputs_to_variable_pool.assert_called_once_with(
            variable_mapping={},
            user_inputs={"question": "hello"},
            variable_pool=sentinel.variable_pool,
            tenant_id="tenant-id",
        )

    def test_run_free_node_wraps_execution_failures(self, monkeypatch):
        """Failures during setup are wrapped in WorkflowNodeRunFailedError
        carrying the node title and the original message."""

        class FakeNodeClass:
            @staticmethod
            def extract_variable_selector_to_variable_mapping(**_kwargs):
                return {}

        class FakeNode:
            id = "node-id"
            title = "Node Title"
            node_type = "parameter-extractor"

            @staticmethod
            def version():
                return "1"

        dify_node_factory = MagicMock()
        dify_node_factory.create_node.return_value = FakeNode()
        monkeypatch.setattr(
            workflow_entry,
            "NODE_TYPE_CLASSES_MAPPING",
            {NodeType.PARAMETER_EXTRACTOR: {"1": FakeNodeClass}},
        )

        with (
            patch.object(workflow_entry.SystemVariable, "default", return_value=sentinel.system_variables),
            patch.object(workflow_entry, "VariablePool", return_value=sentinel.variable_pool),
            patch.object(workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params),
            patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state),
            patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}),
            patch.object(workflow_entry.time, "perf_counter", return_value=123.0),
            patch.object(workflow_entry, "DifyNodeFactory", return_value=dify_node_factory),
            patch.object(
                workflow_entry.WorkflowEntry,
                "mapping_user_inputs_to_variable_pool",
                side_effect=RuntimeError("boom"),
            ),
        ):
            with pytest.raises(WorkflowNodeRunFailedError, match="Node Title run failed: boom"):
                workflow_entry.WorkflowEntry.run_free_node(
                    node_data={"type": NodeType.PARAMETER_EXTRACTOR.value, "title": "Node"},
                    node_id="node-id",
                    tenant_id="tenant-id",
                    user_id="user-id",
                    user_inputs={"question": "hello"},
                )

    def test_handle_special_values_serializes_nested_files(self):
        """File objects are converted via to_dict() at any nesting depth."""
        file = File(
            tenant_id="tenant-id",
            type=FileType.IMAGE,
            transfer_method=FileTransferMethod.REMOTE_URL,
            remote_url="https://example.com/image.png",
            filename="image.png",
            extension=".png",
        )

        result = workflow_entry.WorkflowEntry.handle_special_values({"file": file, "nested": {"files": [file]}})

        assert result == {
            "file": file.to_dict(),
            "nested": {"files": [file.to_dict()]},
        }

    def test_handle_special_values_returns_none_for_none(self):
        """None passes through unchanged."""
        assert workflow_entry.WorkflowEntry._handle_special_values(None) is None

    def test_handle_special_values_returns_scalar_as_is(self):
        """Plain scalars pass through unchanged."""
        assert workflow_entry.WorkflowEntry._handle_special_values("plain-text") == "plain-text"
|
||||
|
||||
|
||||
class TestMappingUserInputsBranches:
    """Branch coverage for WorkflowEntry.mapping_user_inputs_to_variable_pool."""

    def test_rejects_invalid_node_variable_key(self):
        """A mapping key whose split() yields no parts raises ValueError."""

        class EmptySplitKey(UserString):
            # Force the degenerate case: a string-like key that splits to [].
            def split(self, _sep=None):
                return []

        with pytest.raises(ValueError, match="Invalid node variable broken"):
            workflow_entry.WorkflowEntry.mapping_user_inputs_to_variable_pool(
                variable_mapping={EmptySplitKey("broken"): ["node", "input"]},
                user_inputs={},
                variable_pool=MagicMock(),
                tenant_id="tenant-id",
            )

    def test_skips_none_user_input_when_variable_already_exists(self):
        """A None user input does not add anything to the variable pool."""
        variable_pool = MagicMock()
        # NOTE(review): pool.get returns None here, yet nothing is added —
        # presumably the None user input itself short-circuits; confirm in impl.
        variable_pool.get.return_value = None

        workflow_entry.WorkflowEntry.mapping_user_inputs_to_variable_pool(
            variable_mapping={"node.input": ["target", "input"]},
            user_inputs={"node.input": None},
            variable_pool=variable_pool,
            tenant_id="tenant-id",
        )

        variable_pool.add.assert_not_called()

    def test_merges_structured_output_values(self):
        """New values for structured_output sub-keys are merged into the
        existing structured_output dict rather than replacing it."""
        variable_pool = MagicMock()
        # First get: the target variable itself (absent); second get: the
        # existing structured_output container to merge into.
        variable_pool.get.side_effect = [
            None,
            SimpleNamespace(value={"existing": "value"}),
        ]

        workflow_entry.WorkflowEntry.mapping_user_inputs_to_variable_pool(
            variable_mapping={"node.answer": ["target", "structured_output", "answer"]},
            user_inputs={"node.answer": "new-value"},
            variable_pool=variable_pool,
            tenant_id="tenant-id",
        )

        variable_pool.add.assert_called_once_with(
            ["target", "structured_output"],
            {"existing": "value", "answer": "new-value"},
        )
|
||||
|
||||
|
||||
class TestWorkflowEntryTracing:
    """Tests for the _traced_node_run observability wrapper."""

    def test_traced_node_run_reports_success(self):
        """Events are yielded through and the layer sees start/end callbacks,
        with None as the error argument on success."""
        layer = MagicMock()

        class FakeNode:
            def ensure_execution_id(self):
                return None

            def run(self):
                yield "event"

        with patch.object(workflow_entry, "ObservabilityLayer", return_value=layer):
            events = list(workflow_entry.WorkflowEntry._traced_node_run(FakeNode()))

        assert events == ["event"]
        layer.on_graph_start.assert_called_once_with()
        layer.on_node_run_start.assert_called_once()
        # The end callback must receive the same event object start received.
        layer.on_node_run_end.assert_called_once_with(
            layer.on_node_run_start.call_args.args[0],
            None,
        )

    def test_traced_node_run_reports_errors(self):
        """An exception from node.run propagates and is also reported to the
        layer's on_node_run_end as the error argument."""
        layer = MagicMock()

        class FakeNode:
            def ensure_execution_id(self):
                return None

            def run(self):
                raise RuntimeError("boom")
                # Unreachable yield keeps run() a generator function.
                yield

        with patch.object(workflow_entry, "ObservabilityLayer", return_value=layer):
            with pytest.raises(RuntimeError, match="boom"):
                list(workflow_entry.WorkflowEntry._traced_node_run(FakeNode()))

        assert isinstance(layer.on_node_run_end.call_args.args[1], RuntimeError)
|
||||
@ -0,0 +1,157 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def jina_module() -> ModuleType:
    """
    Load `api/services/auth/jina.py` as a standalone module.

    This repo contains both `services/auth/jina.py` and a package at
    `services/auth/jina/`, so importing `services.auth.jina` can be ambiguous.
    """

    module_path = Path(__file__).resolve().parents[4] / "services" / "auth" / "jina.py"
    # Use a stable module name so pytest-cov can target it with `--cov=services.auth.jina_file`.
    spec = importlib.util.spec_from_file_location("services.auth.jina_file", module_path)
    assert spec is not None
    assert spec.loader is not None
    module = importlib.util.module_from_spec(spec)
    # Register in sys.modules BEFORE exec_module so any self-imports during
    # execution resolve to this module object.
    sys.modules[spec.name] = module
    spec.loader.exec_module(module)
    return module
|
||||
|
||||
|
||||
def _credentials(api_key: str | None = "test_api_key_123", auth_type: str = "bearer") -> dict:
|
||||
config: dict = {} if api_key is None else {"api_key": api_key}
|
||||
return {"auth_type": auth_type, "config": config}
|
||||
|
||||
|
||||
def test_init_valid_bearer_credentials(jina_module: ModuleType) -> None:
    """A bearer credential dict yields a JinaAuth exposing key and auth type."""
    credentials = _credentials()
    auth = jina_module.JinaAuth(credentials)
    assert auth.credentials["auth_type"] == "bearer"
    assert auth.api_key == "test_api_key_123"
|
||||
|
||||
|
||||
def test_init_rejects_invalid_auth_type(jina_module: ModuleType) -> None:
    """Non-bearer auth types are rejected at construction time."""
    bad_credentials = _credentials(auth_type="basic")
    with pytest.raises(ValueError, match="Invalid auth type.*Bearer"):
        jina_module.JinaAuth(bad_credentials)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("credentials", [{"auth_type": "bearer", "config": {}}, {"auth_type": "bearer"}])
def test_init_requires_api_key(jina_module: ModuleType, credentials: dict) -> None:
    """An empty config and a missing config both raise 'No API key provided'."""
    with pytest.raises(ValueError, match="No API key provided"):
        jina_module.JinaAuth(credentials)
|
||||
|
||||
|
||||
def test_prepare_headers_includes_bearer_api_key(jina_module: ModuleType) -> None:
    """_prepare_headers returns JSON content type plus a Bearer authorization."""
    auth = jina_module.JinaAuth(_credentials(api_key="k"))
    expected = {"Content-Type": "application/json", "Authorization": "Bearer k"}
    assert auth._prepare_headers() == expected
|
||||
|
||||
|
||||
def test_post_request_calls_httpx(jina_module: ModuleType, monkeypatch: pytest.MonkeyPatch) -> None:
    """_post_request forwards url, headers, and JSON body straight to httpx.post."""
    auth = jina_module.JinaAuth(_credentials(api_key="k"))
    mocked_post = MagicMock(name="httpx.post")
    monkeypatch.setattr(jina_module.httpx, "post", mocked_post)

    payload = {"url": "https://example.com"}
    headers = {"h": "v"}
    auth._post_request("https://r.jina.ai", payload, headers)

    mocked_post.assert_called_once_with("https://r.jina.ai", headers=headers, json=payload)
|
||||
|
||||
|
||||
def test_validate_credentials_success(jina_module: ModuleType, monkeypatch: pytest.MonkeyPatch) -> None:
    """A 200 response from the probe request marks the credentials as valid."""
    auth = jina_module.JinaAuth(_credentials(api_key="k"))

    ok_response = MagicMock()
    ok_response.status_code = 200
    mocked_post = MagicMock(return_value=ok_response)
    monkeypatch.setattr(jina_module.httpx, "post", mocked_post)

    assert auth.validate_credentials() is True
    mocked_post.assert_called_once_with(
        "https://r.jina.ai",
        headers={"Content-Type": "application/json", "Authorization": "Bearer k"},
        json={"url": "https://example.com"},
    )
|
||||
|
||||
|
||||
def test_validate_credentials_non_200_raises_via_handle_error(
    jina_module: ModuleType, monkeypatch: pytest.MonkeyPatch
) -> None:
    """Non-200 probe responses are routed through _handle_error and raise."""
    auth = jina_module.JinaAuth(_credentials(api_key="k"))

    payment_required = MagicMock()
    payment_required.status_code = 402
    payment_required.json.return_value = {"error": "Payment required"}
    monkeypatch.setattr(jina_module.httpx, "post", MagicMock(return_value=payment_required))

    with pytest.raises(Exception, match="Status code: 402.*Payment required"):
        auth.validate_credentials()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("status_code", [402, 409, 500])
def test_handle_error_statuses_use_response_json(jina_module: ModuleType, status_code: int) -> None:
    """402/409/500 responses surface the JSON 'error' field in the message."""
    auth = jina_module.JinaAuth(_credentials(api_key="k"))
    failing = MagicMock()
    failing.status_code = status_code
    failing.json.return_value = {"error": "boom"}

    with pytest.raises(Exception, match=f"Status code: {status_code}.*boom"):
        auth._handle_error(failing)
|
||||
|
||||
|
||||
def test_handle_error_statuses_default_unknown_error(jina_module: ModuleType) -> None:
    """A JSON body without an 'error' field falls back to 'Unknown error occurred'."""
    auth = jina_module.JinaAuth(_credentials(api_key="k"))
    empty_body = MagicMock()
    empty_body.status_code = 402
    empty_body.json.return_value = {}

    with pytest.raises(Exception, match="Unknown error occurred"):
        auth._handle_error(empty_body)
|
||||
|
||||
|
||||
def test_handle_error_with_text_json_body(jina_module: ModuleType) -> None:
    """For statuses parsed via response.text, the embedded 'error' is surfaced."""
    auth = jina_module.JinaAuth(_credentials(api_key="k"))
    forbidden = MagicMock()
    forbidden.status_code = 403
    forbidden.text = '{"error": "Forbidden"}'

    with pytest.raises(Exception, match="Status code: 403.*Forbidden"):
        auth._handle_error(forbidden)
|
||||
|
||||
|
||||
def test_handle_error_with_text_json_body_missing_error(jina_module: ModuleType) -> None:
    """A text JSON body without 'error' falls back to 'Unknown error occurred'."""
    auth = jina_module.JinaAuth(_credentials(api_key="k"))
    empty_json_text = MagicMock()
    empty_json_text.status_code = 403
    empty_json_text.text = "{}"

    with pytest.raises(Exception, match="Unknown error occurred"):
        auth._handle_error(empty_json_text)
|
||||
|
||||
|
||||
def test_handle_error_without_text_raises_unexpected(jina_module: ModuleType) -> None:
    """An empty response body yields the generic unexpected-error message."""
    auth = jina_module.JinaAuth(_credentials(api_key="k"))
    bodyless = MagicMock()
    bodyless.status_code = 404
    bodyless.text = ""

    with pytest.raises(Exception, match="Unexpected error occurred.*404"):
        auth._handle_error(bodyless)
|
||||
|
||||
|
||||
def test_validate_credentials_propagates_network_errors(
    jina_module: ModuleType, monkeypatch: pytest.MonkeyPatch
) -> None:
    """Transport-level httpx failures are not swallowed by validate_credentials."""
    auth = jina_module.JinaAuth(_credentials(api_key="k"))
    connect_error = httpx.ConnectError("boom")
    monkeypatch.setattr(jina_module.httpx, "post", MagicMock(side_effect=connect_error))

    with pytest.raises(httpx.ConnectError, match="boom"):
        auth.validate_credentials()
|
||||
# --- new file: api/tests/unit_tests/services/test_ops_service.py (381 lines) ---
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.ops.entities.config_entity import TracingProviderEnum
|
||||
from models.model import App, TraceAppConfig
|
||||
from services.ops_service import OpsService
|
||||
|
||||
|
||||
class TestOpsService:
|
||||
@patch("services.ops_service.db")
@patch("services.ops_service.OpsTraceManager")
def test_get_tracing_app_config_no_config(self, mock_ops_trace_manager, mock_db):
    """Returns None when no TraceAppConfig row exists for the app/provider."""
    # Arrange
    mock_db.session.query.return_value.where.return_value.first.return_value = None

    # Act
    result = OpsService.get_tracing_app_config("app_id", "arize")

    # Assert
    assert result is None
    mock_db.session.query.assert_called_with(TraceAppConfig)
|
||||
|
||||
@patch("services.ops_service.db")
@patch("services.ops_service.OpsTraceManager")
def test_get_tracing_app_config_no_app(self, mock_ops_trace_manager, mock_db):
    """Returns None when the trace config exists but the App row does not."""
    # Arrange: first query returns the trace config, second (the App) None.
    trace_config = MagicMock(spec=TraceAppConfig)
    mock_db.session.query.return_value.where.return_value.first.side_effect = [trace_config, None]

    # Act
    result = OpsService.get_tracing_app_config("app_id", "arize")

    # Assert: both queries were issued before giving up.
    assert result is None
    assert mock_db.session.query.call_count == 2
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_get_tracing_app_config_none_config(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
trace_config = MagicMock(spec=TraceAppConfig)
|
||||
trace_config.tracing_config = None
|
||||
app = MagicMock(spec=App)
|
||||
app.tenant_id = "tenant_id"
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [trace_config, app]
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="Tracing config cannot be None."):
|
||||
OpsService.get_tracing_app_config("app_id", "arize")
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
@pytest.mark.parametrize(
|
||||
("provider", "default_url"),
|
||||
[
|
||||
("arize", "https://app.arize.com/"),
|
||||
("phoenix", "https://app.phoenix.arize.com/projects/"),
|
||||
("langsmith", "https://smith.langchain.com/"),
|
||||
("opik", "https://www.comet.com/opik/"),
|
||||
("weave", "https://wandb.ai/"),
|
||||
("aliyun", "https://arms.console.aliyun.com/"),
|
||||
("tencent", "https://console.cloud.tencent.com/apm"),
|
||||
("mlflow", "http://localhost:5000/"),
|
||||
("databricks", "https://www.databricks.com/"),
|
||||
],
|
||||
)
|
||||
def test_get_tracing_app_config_providers_exception(self, mock_ops_trace_manager, mock_db, provider, default_url):
|
||||
# Arrange
|
||||
trace_config = MagicMock(spec=TraceAppConfig)
|
||||
trace_config.tracing_config = {"some": "config"}
|
||||
trace_config.to_dict.return_value = {"tracing_config": {"project_url": default_url}}
|
||||
app = MagicMock(spec=App)
|
||||
app.tenant_id = "tenant_id"
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [trace_config, app]
|
||||
|
||||
mock_ops_trace_manager.decrypt_tracing_config.return_value = {}
|
||||
mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {}
|
||||
mock_ops_trace_manager.get_trace_config_project_url.side_effect = Exception("error")
|
||||
mock_ops_trace_manager.get_trace_config_project_key.side_effect = Exception("error")
|
||||
|
||||
# Act
|
||||
result = OpsService.get_tracing_app_config("app_id", provider)
|
||||
|
||||
# Assert
|
||||
assert result["tracing_config"]["project_url"] == default_url
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
@pytest.mark.parametrize(
|
||||
"provider", ["arize", "phoenix", "langsmith", "opik", "weave", "aliyun", "tencent", "mlflow", "databricks"]
|
||||
)
|
||||
def test_get_tracing_app_config_providers_success(self, mock_ops_trace_manager, mock_db, provider):
|
||||
# Arrange
|
||||
trace_config = MagicMock(spec=TraceAppConfig)
|
||||
trace_config.tracing_config = {"some": "config"}
|
||||
trace_config.to_dict.return_value = {"tracing_config": {"project_url": "success_url"}}
|
||||
app = MagicMock(spec=App)
|
||||
app.tenant_id = "tenant_id"
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [trace_config, app]
|
||||
|
||||
mock_ops_trace_manager.decrypt_tracing_config.return_value = {}
|
||||
mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {}
|
||||
mock_ops_trace_manager.get_trace_config_project_url.return_value = "success_url"
|
||||
|
||||
# Act
|
||||
result = OpsService.get_tracing_app_config("app_id", provider)
|
||||
|
||||
# Assert
|
||||
assert result["tracing_config"]["project_url"] == "success_url"
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_get_tracing_app_config_langfuse_success(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
trace_config = MagicMock(spec=TraceAppConfig)
|
||||
trace_config.tracing_config = {"some": "config"}
|
||||
trace_config.to_dict.return_value = {"tracing_config": {"project_url": "https://api.langfuse.com/project/key"}}
|
||||
app = MagicMock(spec=App)
|
||||
app.tenant_id = "tenant_id"
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [trace_config, app]
|
||||
|
||||
mock_ops_trace_manager.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"}
|
||||
mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"}
|
||||
mock_ops_trace_manager.get_trace_config_project_key.return_value = "key"
|
||||
|
||||
# Act
|
||||
result = OpsService.get_tracing_app_config("app_id", "langfuse")
|
||||
|
||||
# Assert
|
||||
assert result["tracing_config"]["project_url"] == "https://api.langfuse.com/project/key"
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_get_tracing_app_config_langfuse_exception(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
trace_config = MagicMock(spec=TraceAppConfig)
|
||||
trace_config.tracing_config = {"some": "config"}
|
||||
trace_config.to_dict.return_value = {"tracing_config": {"project_url": "https://api.langfuse.com/"}}
|
||||
app = MagicMock(spec=App)
|
||||
app.tenant_id = "tenant_id"
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [trace_config, app]
|
||||
|
||||
mock_ops_trace_manager.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"}
|
||||
mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"}
|
||||
mock_ops_trace_manager.get_trace_config_project_key.side_effect = Exception("error")
|
||||
|
||||
# Act
|
||||
result = OpsService.get_tracing_app_config("app_id", "langfuse")
|
||||
|
||||
# Assert
|
||||
assert result["tracing_config"]["project_url"] == "https://api.langfuse.com/"
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_create_tracing_app_config_invalid_provider(self, mock_ops_trace_manager, mock_db):
|
||||
# Act
|
||||
result = OpsService.create_tracing_app_config("app_id", "invalid_provider", {})
|
||||
|
||||
# Assert
|
||||
assert result == {"error": "Invalid tracing provider: invalid_provider"}
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_create_tracing_app_config_invalid_credentials(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
provider = TracingProviderEnum.LANGFUSE
|
||||
mock_ops_trace_manager.check_trace_config_is_effective.return_value = False
|
||||
|
||||
# Act
|
||||
result = OpsService.create_tracing_app_config("app_id", provider, {"public_key": "p", "secret_key": "s"})
|
||||
|
||||
# Assert
|
||||
assert result == {"error": "Invalid Credentials"}
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
@pytest.mark.parametrize(
|
||||
("provider", "config"),
|
||||
[
|
||||
(TracingProviderEnum.ARIZE, {}),
|
||||
(TracingProviderEnum.LANGFUSE, {"public_key": "p", "secret_key": "s"}),
|
||||
(TracingProviderEnum.LANGSMITH, {"api_key": "k", "project": "p"}),
|
||||
(TracingProviderEnum.ALIYUN, {"license_key": "k", "endpoint": "https://aliyun.com"}),
|
||||
],
|
||||
)
|
||||
def test_create_tracing_app_config_project_url_exception(self, mock_ops_trace_manager, mock_db, provider, config):
|
||||
# Arrange
|
||||
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
|
||||
mock_ops_trace_manager.get_trace_config_project_url.side_effect = Exception("error")
|
||||
mock_ops_trace_manager.get_trace_config_project_key.side_effect = Exception("error")
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock(spec=TraceAppConfig)
|
||||
|
||||
# Act
|
||||
result = OpsService.create_tracing_app_config("app_id", provider, config)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_create_tracing_app_config_langfuse_success(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
provider = TracingProviderEnum.LANGFUSE
|
||||
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
|
||||
mock_ops_trace_manager.get_trace_config_project_key.return_value = "key"
|
||||
app = MagicMock(spec=App)
|
||||
app.tenant_id = "tenant_id"
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [None, app]
|
||||
mock_ops_trace_manager.encrypt_tracing_config.return_value = {}
|
||||
|
||||
# Act
|
||||
result = OpsService.create_tracing_app_config(
|
||||
"app_id", provider, {"public_key": "p", "secret_key": "s", "host": "https://api.langfuse.com"}
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == {"result": "success"}
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_create_tracing_app_config_already_exists(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
provider = TracingProviderEnum.ARIZE
|
||||
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock(spec=TraceAppConfig)
|
||||
|
||||
# Act
|
||||
result = OpsService.create_tracing_app_config("app_id", provider, {})
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_create_tracing_app_config_no_app(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
provider = TracingProviderEnum.ARIZE
|
||||
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [None, None]
|
||||
|
||||
# Act
|
||||
result = OpsService.create_tracing_app_config("app_id", provider, {})
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_create_tracing_app_config_with_empty_other_keys(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
provider = TracingProviderEnum.ARIZE
|
||||
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
|
||||
app = MagicMock(spec=App)
|
||||
app.tenant_id = "tenant_id"
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [None, app]
|
||||
mock_ops_trace_manager.encrypt_tracing_config.return_value = {}
|
||||
|
||||
# Act
|
||||
# 'project' is in other_keys for Arize
|
||||
# provide an empty string for the project in the tracing_config
|
||||
# create_tracing_app_config will replace it with the default from the model
|
||||
result = OpsService.create_tracing_app_config("app_id", provider, {"project": ""})
|
||||
|
||||
# Assert
|
||||
assert result == {"result": "success"}
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_create_tracing_app_config_success(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
provider = TracingProviderEnum.ARIZE
|
||||
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
|
||||
mock_ops_trace_manager.get_trace_config_project_url.return_value = "http://project_url"
|
||||
app = MagicMock(spec=App)
|
||||
app.tenant_id = "tenant_id"
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [None, app]
|
||||
mock_ops_trace_manager.encrypt_tracing_config.return_value = {"encrypted": "config"}
|
||||
|
||||
# Act
|
||||
result = OpsService.create_tracing_app_config("app_id", provider, {})
|
||||
|
||||
# Assert
|
||||
assert result == {"result": "success"}
|
||||
mock_db.session.add.assert_called()
|
||||
mock_db.session.commit.assert_called()
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_update_tracing_app_config_invalid_provider(self, mock_ops_trace_manager, mock_db):
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="Invalid tracing provider: invalid_provider"):
|
||||
OpsService.update_tracing_app_config("app_id", "invalid_provider", {})
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_update_tracing_app_config_no_config(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
provider = TracingProviderEnum.ARIZE
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
# Act
|
||||
result = OpsService.update_tracing_app_config("app_id", provider, {})
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_update_tracing_app_config_no_app(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
provider = TracingProviderEnum.ARIZE
|
||||
current_config = MagicMock(spec=TraceAppConfig)
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [current_config, None]
|
||||
|
||||
# Act
|
||||
result = OpsService.update_tracing_app_config("app_id", provider, {})
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_update_tracing_app_config_invalid_credentials(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
provider = TracingProviderEnum.ARIZE
|
||||
current_config = MagicMock(spec=TraceAppConfig)
|
||||
app = MagicMock(spec=App)
|
||||
app.tenant_id = "tenant_id"
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [current_config, app]
|
||||
mock_ops_trace_manager.decrypt_tracing_config.return_value = {}
|
||||
mock_ops_trace_manager.check_trace_config_is_effective.return_value = False
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="Invalid Credentials"):
|
||||
OpsService.update_tracing_app_config("app_id", provider, {})
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_update_tracing_app_config_success(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
provider = TracingProviderEnum.ARIZE
|
||||
current_config = MagicMock(spec=TraceAppConfig)
|
||||
current_config.to_dict.return_value = {"some": "data"}
|
||||
app = MagicMock(spec=App)
|
||||
app.tenant_id = "tenant_id"
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [current_config, app]
|
||||
mock_ops_trace_manager.decrypt_tracing_config.return_value = {}
|
||||
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
|
||||
|
||||
# Act
|
||||
result = OpsService.update_tracing_app_config("app_id", provider, {})
|
||||
|
||||
# Assert
|
||||
assert result == {"some": "data"}
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
def test_delete_tracing_app_config_no_config(self, mock_db):
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
# Act
|
||||
result = OpsService.delete_tracing_app_config("app_id", "arize")
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
def test_delete_tracing_app_config_success(self, mock_db):
|
||||
# Arrange
|
||||
trace_config = MagicMock(spec=TraceAppConfig)
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = trace_config
|
||||
|
||||
# Act
|
||||
result = OpsService.delete_tracing_app_config("app_id", "arize")
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_db.session.delete.assert_called_with(trace_config)
|
||||
mock_db.session.commit.assert_called_once()
|
||||
1329
api/tests/unit_tests/services/test_summary_index_service.py
Normal file
1329
api/tests/unit_tests/services/test_summary_index_service.py
Normal file
File diff suppressed because it is too large
Load Diff
704
api/tests/unit_tests/services/test_vector_service.py
Normal file
704
api/tests/unit_tests/services/test_vector_service.py
Normal file
@ -0,0 +1,704 @@
|
||||
"""Unit tests for `api/services/vector_service.py`."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
import services.vector_service as vector_service_module
|
||||
from services.vector_service import VectorService
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _UploadFileStub:
|
||||
id: str
|
||||
name: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _ChildDocStub:
|
||||
page_content: str
|
||||
metadata: dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class _ParentDocStub:
|
||||
children: list[_ChildDocStub]
|
||||
|
||||
|
||||
def _make_dataset(
|
||||
*,
|
||||
indexing_technique: str = "high_quality",
|
||||
doc_form: str = "text_model",
|
||||
tenant_id: str = "tenant-1",
|
||||
dataset_id: str = "dataset-1",
|
||||
is_multimodal: bool = False,
|
||||
embedding_model_provider: str | None = "openai",
|
||||
embedding_model: str = "text-embedding",
|
||||
) -> MagicMock:
|
||||
dataset = MagicMock(name="dataset")
|
||||
dataset.id = dataset_id
|
||||
dataset.tenant_id = tenant_id
|
||||
dataset.doc_form = doc_form
|
||||
dataset.indexing_technique = indexing_technique
|
||||
dataset.is_multimodal = is_multimodal
|
||||
dataset.embedding_model_provider = embedding_model_provider
|
||||
dataset.embedding_model = embedding_model
|
||||
return dataset
|
||||
|
||||
|
||||
def _make_segment(
|
||||
*,
|
||||
segment_id: str = "seg-1",
|
||||
tenant_id: str = "tenant-1",
|
||||
dataset_id: str = "dataset-1",
|
||||
document_id: str = "doc-1",
|
||||
content: str = "hello",
|
||||
index_node_id: str = "node-1",
|
||||
index_node_hash: str = "hash-1",
|
||||
attachments: list[dict[str, str]] | None = None,
|
||||
) -> MagicMock:
|
||||
segment = MagicMock(name="segment")
|
||||
segment.id = segment_id
|
||||
segment.tenant_id = tenant_id
|
||||
segment.dataset_id = dataset_id
|
||||
segment.document_id = document_id
|
||||
segment.content = content
|
||||
segment.index_node_id = index_node_id
|
||||
segment.index_node_hash = index_node_hash
|
||||
segment.attachments = attachments or []
|
||||
return segment
|
||||
|
||||
|
||||
def _mock_db_session_for_update_multimodel(*, upload_files: list[_UploadFileStub] | None) -> MagicMock:
|
||||
session = MagicMock(name="session")
|
||||
|
||||
binding_query = MagicMock(name="binding_query")
|
||||
binding_query.where.return_value = binding_query
|
||||
binding_query.delete.return_value = 1
|
||||
|
||||
upload_query = MagicMock(name="upload_query")
|
||||
upload_query.where.return_value = upload_query
|
||||
upload_query.all.return_value = upload_files or []
|
||||
|
||||
def query_side_effect(model: object) -> MagicMock:
|
||||
if model is vector_service_module.SegmentAttachmentBinding:
|
||||
return binding_query
|
||||
if model is vector_service_module.UploadFile:
|
||||
return upload_query
|
||||
return MagicMock(name=f"query({model})")
|
||||
|
||||
session.query.side_effect = query_side_effect
|
||||
db_mock = MagicMock(name="db")
|
||||
db_mock.session = session
|
||||
return db_mock
|
||||
|
||||
|
||||
def test_create_segments_vector_regular_indexing_loads_documents_and_keywords(monkeypatch: pytest.MonkeyPatch) -> None:
    """Non-multimodal text_model indexing loads one document with the keyword list attached."""
    dataset = _make_dataset(is_multimodal=False)
    segment = _make_segment()

    processor = MagicMock(name="index_processor")
    factory = MagicMock(name="IndexProcessorFactory-instance")
    factory.init_index_processor.return_value = processor
    monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory))

    VectorService.create_segments_vector([["k1"]], [segment], dataset, "text_model")

    processor.load.assert_called_once()
    call_args, call_kwargs = processor.load.call_args
    assert call_args[0] == dataset
    assert len(call_args[1]) == 1
    assert call_args[2] is None
    assert call_kwargs["with_keywords"] is True
    assert call_kwargs["keywords_list"] == [["k1"]]
|
||||
|
||||
|
||||
def test_create_segments_vector_regular_indexing_loads_multimodal_documents(monkeypatch: pytest.MonkeyPatch) -> None:
    """Multimodal datasets trigger a second load() call carrying the attachment documents."""
    dataset = _make_dataset(is_multimodal=True)
    segment = _make_segment(
        attachments=[
            {"id": "img-1", "name": "a.png"},
            {"id": "img-2", "name": "b.png"},
        ]
    )

    processor = MagicMock(name="index_processor")
    factory = MagicMock(name="IndexProcessorFactory-instance")
    factory.init_index_processor.return_value = processor
    monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory))

    VectorService.create_segments_vector([["k1"]], [segment], dataset, "text_model")

    assert processor.load.call_count == 2

    # First call: the text documents, with keywords.
    first_args, first_kwargs = processor.load.call_args_list[0]
    assert first_args[0] == dataset
    assert len(first_args[1]) == 1
    assert first_kwargs["with_keywords"] is True

    # Second call: no text documents, both attachments, no keywords.
    second_args, second_kwargs = processor.load.call_args_list[1]
    assert second_args[0] == dataset
    assert second_args[1] == []
    assert len(second_args[2]) == 2
    assert second_kwargs["with_keywords"] is False
|
||||
|
||||
|
||||
def test_create_segments_vector_with_no_segments_does_not_load(monkeypatch: pytest.MonkeyPatch) -> None:
    """An empty segment list must not invoke the index processor at all."""
    dataset = _make_dataset()
    processor = MagicMock(name="index_processor")
    factory = MagicMock()
    factory.init_index_processor.return_value = processor
    monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory))

    VectorService.create_segments_vector(None, [], dataset, "text_model")
    processor.load.assert_not_called()
|
||||
|
||||
|
||||
def _mock_parent_child_queries(
|
||||
*,
|
||||
dataset_document: object | None,
|
||||
processing_rule: object | None,
|
||||
) -> MagicMock:
|
||||
session = MagicMock(name="session")
|
||||
|
||||
doc_query = MagicMock(name="doc_query")
|
||||
doc_query.filter_by.return_value = doc_query
|
||||
doc_query.first.return_value = dataset_document
|
||||
|
||||
rule_query = MagicMock(name="rule_query")
|
||||
rule_query.where.return_value = rule_query
|
||||
rule_query.first.return_value = processing_rule
|
||||
|
||||
def query_side_effect(model: object) -> MagicMock:
|
||||
if model is vector_service_module.DatasetDocument:
|
||||
return doc_query
|
||||
if model is vector_service_module.DatasetProcessRule:
|
||||
return rule_query
|
||||
return MagicMock(name=f"query({model})")
|
||||
|
||||
session.query.side_effect = query_side_effect
|
||||
db_mock = MagicMock(name="db")
|
||||
db_mock.session = session
|
||||
return db_mock
|
||||
|
||||
|
||||
def test_create_segments_vector_parent_child_calls_generate_child_chunks_with_explicit_model(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Parent-child indexing with an explicit provider resolves that model and delegates to generate_child_chunks."""
    dataset = _make_dataset(
        doc_form=vector_service_module.IndexStructureType.PARENT_CHILD_INDEX,
        embedding_model_provider="openai",
        indexing_technique="high_quality",
    )
    segment = _make_segment()

    dataset_document = MagicMock(name="dataset_document")
    dataset_document.id = segment.document_id
    dataset_document.dataset_process_rule_id = "rule-1"
    dataset_document.doc_language = "en"
    dataset_document.created_by = "user-1"

    processing_rule = MagicMock(name="processing_rule")
    processing_rule.to_dict.return_value = {"rules": {}}

    monkeypatch.setattr(
        vector_service_module,
        "db",
        _mock_parent_child_queries(dataset_document=dataset_document, processing_rule=processing_rule),
    )

    # Explicit provider → get_model_instance (not the default) must be used.
    embedding_model_instance = MagicMock(name="embedding_model_instance")
    model_manager = MagicMock(name="model_manager_instance")
    model_manager.get_model_instance.return_value = embedding_model_instance
    monkeypatch.setattr(vector_service_module, "ModelManager", MagicMock(return_value=model_manager))

    generate_child_chunks_mock = MagicMock()
    monkeypatch.setattr(VectorService, "generate_child_chunks", generate_child_chunks_mock)

    processor = MagicMock()
    factory = MagicMock()
    factory.init_index_processor.return_value = processor
    monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory))

    VectorService.create_segments_vector(
        None, [segment], dataset, vector_service_module.IndexStructureType.PARENT_CHILD_INDEX
    )

    model_manager.get_model_instance.assert_called_once()
    generate_child_chunks_mock.assert_called_once_with(
        segment, dataset_document, dataset, embedding_model_instance, processing_rule, False
    )
    # Parent-child path delegates entirely; the processor itself is not loaded.
    processor.load.assert_not_called()
|
||||
|
||||
|
||||
def test_create_segments_vector_parent_child_uses_default_embedding_model_when_provider_missing(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Without an embedding_model_provider the tenant's default embedding model is resolved instead."""
    dataset = _make_dataset(
        doc_form=vector_service_module.IndexStructureType.PARENT_CHILD_INDEX,
        embedding_model_provider=None,
        indexing_technique="high_quality",
    )
    segment = _make_segment()

    dataset_document = MagicMock()
    dataset_document.dataset_process_rule_id = "rule-1"
    dataset_document.doc_language = "en"
    dataset_document.created_by = "user-1"

    processing_rule = MagicMock()
    processing_rule.to_dict.return_value = {"rules": {}}

    monkeypatch.setattr(
        vector_service_module,
        "db",
        _mock_parent_child_queries(dataset_document=dataset_document, processing_rule=processing_rule),
    )

    # No provider configured → get_default_model_instance must be used.
    embedding_model_instance = MagicMock()
    model_manager = MagicMock()
    model_manager.get_default_model_instance.return_value = embedding_model_instance
    monkeypatch.setattr(vector_service_module, "ModelManager", MagicMock(return_value=model_manager))

    generate_child_chunks_mock = MagicMock()
    monkeypatch.setattr(VectorService, "generate_child_chunks", generate_child_chunks_mock)

    processor = MagicMock()
    factory = MagicMock()
    factory.init_index_processor.return_value = processor
    monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory))

    VectorService.create_segments_vector(
        None, [segment], dataset, vector_service_module.IndexStructureType.PARENT_CHILD_INDEX
    )

    model_manager.get_default_model_instance.assert_called_once()
    generate_child_chunks_mock.assert_called_once()
|
||||
|
||||
|
||||
def test_create_segments_vector_parent_child_missing_document_logs_warning_and_continues(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A segment whose dataset document is missing is skipped with a warning, not an error."""
    dataset = _make_dataset(doc_form=vector_service_module.IndexStructureType.PARENT_CHILD_INDEX)
    segment = _make_segment()

    processing_rule = MagicMock()
    monkeypatch.setattr(
        vector_service_module,
        "db",
        _mock_parent_child_queries(dataset_document=None, processing_rule=processing_rule),
    )

    logger_mock = MagicMock()
    monkeypatch.setattr(vector_service_module, "logger", logger_mock)

    processor = MagicMock()
    factory = MagicMock()
    factory.init_index_processor.return_value = processor
    monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory))

    VectorService.create_segments_vector(
        None, [segment], dataset, vector_service_module.IndexStructureType.PARENT_CHILD_INDEX
    )
    logger_mock.warning.assert_called_once()
    processor.load.assert_not_called()
|
||||
|
||||
|
||||
def test_create_segments_vector_parent_child_missing_processing_rule_raises(monkeypatch: pytest.MonkeyPatch) -> None:
    """A parent-child segment without its processing rule raises ValueError."""
    dataset = _make_dataset(doc_form=vector_service_module.IndexStructureType.PARENT_CHILD_INDEX)
    segment = _make_segment()

    dataset_document = MagicMock()
    dataset_document.dataset_process_rule_id = "rule-1"
    monkeypatch.setattr(
        vector_service_module,
        "db",
        _mock_parent_child_queries(dataset_document=dataset_document, processing_rule=None),
    )

    with pytest.raises(ValueError, match="No processing rule found"):
        VectorService.create_segments_vector(
            None, [segment], dataset, vector_service_module.IndexStructureType.PARENT_CHILD_INDEX
        )
|
||||
|
||||
|
||||
def test_create_segments_vector_parent_child_non_high_quality_raises(monkeypatch: pytest.MonkeyPatch) -> None:
    """Parent-child indexing requires a high_quality dataset; anything else raises ValueError."""
    dataset = _make_dataset(
        doc_form=vector_service_module.IndexStructureType.PARENT_CHILD_INDEX,
        indexing_technique="economy",
    )
    segment = _make_segment()
    dataset_document = MagicMock()
    dataset_document.dataset_process_rule_id = "rule-1"
    processing_rule = MagicMock()
    monkeypatch.setattr(
        vector_service_module,
        "db",
        _mock_parent_child_queries(dataset_document=dataset_document, processing_rule=processing_rule),
    )

    with pytest.raises(ValueError, match="not high quality"):
        VectorService.create_segments_vector(
            None, [segment], dataset, vector_service_module.IndexStructureType.PARENT_CHILD_INDEX
        )
|
||||
|
||||
|
||||
def test_update_segment_vector_high_quality_uses_vector(monkeypatch: pytest.MonkeyPatch) -> None:
    """High-quality datasets update via Vector: delete the old node, re-add with duplicate_check."""
    dataset = _make_dataset(indexing_technique="high_quality")
    segment = _make_segment()

    vector_instance = MagicMock()
    monkeypatch.setattr(vector_service_module, "Vector", MagicMock(return_value=vector_instance))

    VectorService.update_segment_vector(["k"], segment, dataset)

    vector_instance.delete_by_ids.assert_called_once_with([segment.index_node_id])
    vector_instance.add_texts.assert_called_once()
    add_args, add_kwargs = vector_instance.add_texts.call_args
    assert len(add_args[0]) == 1
    assert add_kwargs["duplicate_check"] is True
|
||||
|
||||
|
||||
def test_update_segment_vector_economy_uses_keyword_with_keywords_list(monkeypatch: pytest.MonkeyPatch) -> None:
    """Economy datasets update via Keyword, forwarding the given keywords as keywords_list."""
    dataset = _make_dataset(indexing_technique="economy")
    segment = _make_segment()

    keyword_instance = MagicMock()
    monkeypatch.setattr(vector_service_module, "Keyword", MagicMock(return_value=keyword_instance))

    VectorService.update_segment_vector(["a", "b"], segment, dataset)

    keyword_instance.delete_by_ids.assert_called_once_with([segment.index_node_id])
    keyword_instance.add_texts.assert_called_once()
    call_args, call_kwargs = keyword_instance.add_texts.call_args
    assert len(call_args[0]) == 1
    assert call_kwargs["keywords_list"] == [["a", "b"]]
|
||||
|
||||
|
||||
def test_update_segment_vector_economy_uses_keyword_without_keywords_list(monkeypatch: pytest.MonkeyPatch) -> None:
    """Economy datasets omit the keywords_list kwarg entirely when no keywords are given."""
    dataset = _make_dataset(indexing_technique="economy")
    segment = _make_segment()

    keyword_instance = MagicMock()
    monkeypatch.setattr(vector_service_module, "Keyword", MagicMock(return_value=keyword_instance))

    VectorService.update_segment_vector(None, segment, dataset)
    keyword_instance.add_texts.assert_called_once()
    _, kwargs = keyword_instance.add_texts.call_args
    assert "keywords_list" not in kwargs
|
||||
|
||||
|
||||
def test_generate_child_chunks_regenerate_cleans_then_saves_children(monkeypatch: pytest.MonkeyPatch) -> None:
    """Regeneration cleans the old index first, then transforms with full-doc parent mode,
    persists one ChildChunk row per generated child, and loads the new documents."""
    dataset = _make_dataset(doc_form="text_model", tenant_id="tenant-1", dataset_id="dataset-1")
    segment = _make_segment(segment_id="seg-1")

    dataset_document = MagicMock()
    dataset_document.id = segment.document_id
    dataset_document.doc_language = "en"
    dataset_document.created_by = "user-1"

    processing_rule = MagicMock()
    processing_rule.to_dict.return_value = {"rules": {}}

    # Two children under a single parent document come back from the transform step.
    child1 = _ChildDocStub(page_content="c1", metadata={"doc_id": "c1-id", "doc_hash": "c1-h"})
    child2 = _ChildDocStub(page_content="c2", metadata={"doc_id": "c2-id", "doc_hash": "c2-h"})
    transformed = [_ParentDocStub(children=[child1, child2])]

    index_processor = MagicMock()
    index_processor.transform.return_value = transformed
    factory_instance = MagicMock()
    factory_instance.init_index_processor.return_value = index_processor
    monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance))

    # Capture ChildChunk constructor kwargs instead of building real ORM rows.
    child_chunk_ctor = MagicMock(side_effect=lambda **kwargs: kwargs)
    monkeypatch.setattr(vector_service_module, "ChildChunk", child_chunk_ctor)

    db_mock = MagicMock()
    db_mock.session.add = MagicMock()
    db_mock.session.commit = MagicMock()
    monkeypatch.setattr(vector_service_module, "db", db_mock)

    VectorService.generate_child_chunks(
        segment=segment,
        dataset_document=dataset_document,
        dataset=dataset,
        embedding_model_instance=MagicMock(),
        processing_rule=processing_rule,
        regenerate=True,
    )

    # regenerate=True must purge the previous child index before re-indexing.
    index_processor.clean.assert_called_once()
    _, transform_kwargs = index_processor.transform.call_args
    assert transform_kwargs["process_rule"]["rules"]["parent_mode"] == vector_service_module.ParentMode.FULL_DOC
    index_processor.load.assert_called_once()
    assert db_mock.session.add.call_count == 2
    db_mock.session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_generate_child_chunks_commits_even_when_no_children(monkeypatch: pytest.MonkeyPatch) -> None:
    """When the transform yields no children, nothing is indexed or added, but the session still commits."""
    dataset = _make_dataset(doc_form="text_model")
    segment = _make_segment()
    dataset_document = MagicMock()
    dataset_document.doc_language = "en"
    dataset_document.created_by = "user-1"

    processing_rule = MagicMock()
    processing_rule.to_dict.return_value = {"rules": {}}

    index_processor = MagicMock()
    # Parent document with an empty child list — the degenerate case under test.
    index_processor.transform.return_value = [_ParentDocStub(children=[])]
    factory_instance = MagicMock()
    factory_instance.init_index_processor.return_value = index_processor
    monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance))

    db_mock = MagicMock()
    monkeypatch.setattr(vector_service_module, "db", db_mock)

    VectorService.generate_child_chunks(
        segment=segment,
        dataset_document=dataset_document,
        dataset=dataset,
        embedding_model_instance=MagicMock(),
        processing_rule=processing_rule,
        regenerate=False,
    )

    index_processor.load.assert_not_called()
    db_mock.session.add.assert_not_called()
    db_mock.session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_create_child_chunk_vector_high_quality_adds_texts(monkeypatch: pytest.MonkeyPatch) -> None:
    """High-quality datasets add a child-chunk document to the vector index on creation."""
    dataset = _make_dataset(indexing_technique="high_quality")
    child_chunk = MagicMock()
    child_chunk.content = "child"
    child_chunk.index_node_id = "id"
    child_chunk.index_node_hash = "h"
    child_chunk.document_id = "doc-1"
    child_chunk.dataset_id = "dataset-1"

    vector_instance = MagicMock()
    monkeypatch.setattr(vector_service_module, "Vector", MagicMock(return_value=vector_instance))

    VectorService.create_child_chunk_vector(child_chunk, dataset)
    vector_instance.add_texts.assert_called_once()
|
||||
|
||||
|
||||
def test_create_child_chunk_vector_economy_noop(monkeypatch: pytest.MonkeyPatch) -> None:
    """Economy datasets never instantiate Vector when creating a child chunk."""
    fake_vector_cls = MagicMock()
    monkeypatch.setattr(vector_service_module, "Vector", fake_vector_cls)

    stub_chunk = MagicMock()
    stub_chunk.content = "child"
    stub_chunk.index_node_id = "id"
    stub_chunk.index_node_hash = "h"
    stub_chunk.document_id = "doc-1"
    stub_chunk.dataset_id = "dataset-1"

    VectorService.create_child_chunk_vector(stub_chunk, _make_dataset(indexing_technique="economy"))

    fake_vector_cls.assert_not_called()
|
||||
|
||||
|
||||
def test_update_child_chunk_vector_high_quality_updates_vector(monkeypatch: pytest.MonkeyPatch) -> None:
    """Updates delete the ids of updated + deleted chunks, then re-add docs for new + updated chunks."""
    dataset = _make_dataset(indexing_technique="high_quality")

    # A brand-new chunk: indexed for the first time.
    new_chunk = MagicMock()
    new_chunk.content = "n"
    new_chunk.index_node_id = "nid"
    new_chunk.index_node_hash = "nh"
    new_chunk.document_id = "d"
    new_chunk.dataset_id = "ds"

    # An edited chunk: its old embedding is deleted, then re-added.
    upd_chunk = MagicMock()
    upd_chunk.content = "u"
    upd_chunk.index_node_id = "uid"
    upd_chunk.index_node_hash = "uh"
    upd_chunk.document_id = "d"
    upd_chunk.dataset_id = "ds"

    # A removed chunk: deletion only.
    del_chunk = MagicMock()
    del_chunk.index_node_id = "did"

    vector_instance = MagicMock()
    monkeypatch.setattr(vector_service_module, "Vector", MagicMock(return_value=vector_instance))

    VectorService.update_child_chunk_vector([new_chunk], [upd_chunk], [del_chunk], dataset)

    vector_instance.delete_by_ids.assert_called_once_with(["uid", "did"])
    vector_instance.add_texts.assert_called_once()
    docs = vector_instance.add_texts.call_args.args[0]
    # new + updated chunks are re-added; the deleted one is not.
    assert len(docs) == 2
|
||||
|
||||
|
||||
def test_update_child_chunk_vector_economy_noop(monkeypatch: pytest.MonkeyPatch) -> None:
    """Economy datasets never construct a Vector during child-chunk updates."""
    fake_vector_cls = MagicMock()
    monkeypatch.setattr(vector_service_module, "Vector", fake_vector_cls)

    VectorService.update_child_chunk_vector([], [], [], _make_dataset(indexing_technique="economy"))

    fake_vector_cls.assert_not_called()
|
||||
|
||||
|
||||
def test_delete_child_chunk_vector_deletes_by_id(monkeypatch: pytest.MonkeyPatch) -> None:
    """Deleting a child chunk removes exactly its index node id from the vector store."""
    fake_vector = MagicMock()
    monkeypatch.setattr(vector_service_module, "Vector", MagicMock(return_value=fake_vector))

    chunk = MagicMock()
    chunk.index_node_id = "cid"

    VectorService.delete_child_chunk_vector(chunk, _make_dataset())

    fake_vector.delete_by_ids.assert_called_once_with(["cid"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# update_multimodel_vector (missing coverage in previous suites)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_update_multimodel_vector_returns_when_not_high_quality(monkeypatch: pytest.MonkeyPatch) -> None:
    """Non-high-quality datasets bail out before any vector construction or DB access."""
    dataset = _make_dataset(indexing_technique="economy", is_multimodal=True)
    segment = _make_segment(tenant_id="t", attachments=[{"id": "a"}])

    vector_cls = MagicMock()
    db_mock = _mock_db_session_for_update_multimodel(upload_files=[])
    monkeypatch.setattr(vector_service_module, "Vector", vector_cls)
    monkeypatch.setattr(vector_service_module, "db", db_mock)

    VectorService.update_multimodel_vector(segment=segment, attachment_ids=["a"], dataset=dataset)
    vector_cls.assert_not_called()
    db_mock.session.query.assert_not_called()
|
||||
|
||||
|
||||
def test_update_multimodel_vector_returns_when_no_actual_change(monkeypatch: pytest.MonkeyPatch) -> None:
    """Identical attachment sets (order-insensitive) short-circuit without any vector or DB work."""
    dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True)
    segment = _make_segment(tenant_id="t", attachments=[{"id": "a"}, {"id": "b"}])

    vector_cls = MagicMock()
    db_mock = _mock_db_session_for_update_multimodel(upload_files=[])
    monkeypatch.setattr(vector_service_module, "Vector", vector_cls)
    monkeypatch.setattr(vector_service_module, "db", db_mock)

    # Same ids as the existing attachments, different order — treated as no change.
    VectorService.update_multimodel_vector(segment=segment, attachment_ids=["b", "a"], dataset=dataset)
    vector_cls.assert_not_called()
    db_mock.session.query.assert_not_called()
|
||||
|
||||
|
||||
def test_update_multimodel_vector_deletes_bindings_and_commits_on_empty_new_ids(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Clearing all attachments deletes old vectors and bindings, commits, and adds nothing new."""
    dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True)
    segment = _make_segment(tenant_id="tenant-1", attachments=[{"id": "old-1"}, {"id": "old-2"}])

    vector_instance = MagicMock(name="vector_instance")
    vector_cls = MagicMock(return_value=vector_instance)
    db_mock = _mock_db_session_for_update_multimodel(upload_files=[])

    monkeypatch.setattr(vector_service_module, "Vector", vector_cls)
    monkeypatch.setattr(vector_service_module, "db", db_mock)

    VectorService.update_multimodel_vector(segment=segment, attachment_ids=[], dataset=dataset)

    vector_cls.assert_called_once_with(dataset=dataset)
    # Both previously attached ids are purged from the vector index.
    vector_instance.delete_by_ids.assert_called_once_with(["old-1", "old-2"])
    db_mock.session.query.assert_called_once_with(vector_service_module.SegmentAttachmentBinding)
    db_mock.session.commit.assert_called_once()
    db_mock.session.add_all.assert_not_called()
    vector_instance.add_texts.assert_not_called()
|
||||
|
||||
|
||||
def test_update_multimodel_vector_commits_when_no_upload_files_found(monkeypatch: pytest.MonkeyPatch) -> None:
    """When none of the new ids match stored upload files, nothing is added but the commit still runs."""
    dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True)
    segment = _make_segment(tenant_id="tenant-1", attachments=[{"id": "old-1"}])

    vector_instance = MagicMock()
    monkeypatch.setattr(vector_service_module, "Vector", MagicMock(return_value=vector_instance))
    db_mock = _mock_db_session_for_update_multimodel(upload_files=[])
    monkeypatch.setattr(vector_service_module, "db", db_mock)

    VectorService.update_multimodel_vector(segment=segment, attachment_ids=["new-1"], dataset=dataset)

    db_mock.session.commit.assert_called_once()
    db_mock.session.add_all.assert_not_called()
    vector_instance.add_texts.assert_not_called()
|
||||
|
||||
|
||||
def test_update_multimodel_vector_adds_bindings_and_vectors_and_skips_missing_upload_files(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Known upload files become bindings + vector documents; unknown ids are skipped with a warning."""
    dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True)
    segment = _make_segment(segment_id="seg-1", tenant_id="tenant-1", attachments=[{"id": "old-1"}])

    vector_instance = MagicMock()
    monkeypatch.setattr(vector_service_module, "Vector", MagicMock(return_value=vector_instance))
    # Only "file-1" exists as an upload file; "missing" does not.
    db_mock = _mock_db_session_for_update_multimodel(upload_files=[_UploadFileStub(id="file-1", name="img.png")])
    monkeypatch.setattr(vector_service_module, "db", db_mock)

    # Capture binding constructor kwargs instead of real ORM rows.
    binding_ctor = MagicMock(side_effect=lambda **kwargs: kwargs)
    monkeypatch.setattr(vector_service_module, "SegmentAttachmentBinding", binding_ctor)

    logger_mock = MagicMock()
    monkeypatch.setattr(vector_service_module, "logger", logger_mock)

    VectorService.update_multimodel_vector(segment=segment, attachment_ids=["file-1", "missing"], dataset=dataset)

    # The missing id triggers exactly one warning and is otherwise ignored.
    logger_mock.warning.assert_called_once()
    db_mock.session.add_all.assert_called_once()
    bindings = db_mock.session.add_all.call_args.args[0]
    assert len(bindings) == 1
    assert bindings[0]["attachment_id"] == "file-1"

    vector_instance.add_texts.assert_called_once()
    documents = vector_instance.add_texts.call_args.args[0]
    assert len(documents) == 1
    assert documents[0].page_content == "img.png"
    assert documents[0].metadata["doc_id"] == "file-1"
    db_mock.session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_update_multimodel_vector_updates_bindings_without_multimodal_vector_ops(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Non-multimodal datasets update attachment bindings only; vector delete/add are never invoked."""
    dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=False)
    segment = _make_segment(tenant_id="tenant-1", attachments=[{"id": "old-1"}])

    vector_instance = MagicMock()
    monkeypatch.setattr(vector_service_module, "Vector", MagicMock(return_value=vector_instance))
    db_mock = _mock_db_session_for_update_multimodel(upload_files=[_UploadFileStub(id="file-1", name="img.png")])
    monkeypatch.setattr(vector_service_module, "db", db_mock)
    monkeypatch.setattr(
        vector_service_module, "SegmentAttachmentBinding", MagicMock(side_effect=lambda **kwargs: kwargs)
    )

    VectorService.update_multimodel_vector(segment=segment, attachment_ids=["file-1"], dataset=dataset)

    vector_instance.delete_by_ids.assert_not_called()
    vector_instance.add_texts.assert_not_called()
    db_mock.session.add_all.assert_called_once()
    db_mock.session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_update_multimodel_vector_rolls_back_and_reraises_on_error(monkeypatch: pytest.MonkeyPatch) -> None:
    """A failing commit is logged via logger.exception, rolled back, and re-raised to the caller."""
    dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True)
    segment = _make_segment(segment_id="seg-1", tenant_id="tenant-1", attachments=[{"id": "old-1"}])

    vector_instance = MagicMock()
    monkeypatch.setattr(vector_service_module, "Vector", MagicMock(return_value=vector_instance))
    db_mock = _mock_db_session_for_update_multimodel(upload_files=[_UploadFileStub(id="file-1", name="img.png")])
    # Force the commit to blow up so the error path is exercised.
    db_mock.session.commit.side_effect = RuntimeError("boom")
    monkeypatch.setattr(vector_service_module, "db", db_mock)
    monkeypatch.setattr(
        vector_service_module, "SegmentAttachmentBinding", MagicMock(side_effect=lambda **kwargs: kwargs)
    )

    logger_mock = MagicMock()
    monkeypatch.setattr(vector_service_module, "logger", logger_mock)

    with pytest.raises(RuntimeError, match="boom"):
        VectorService.update_multimodel_vector(segment=segment, attachment_ids=["file-1"], dataset=dataset)

    logger_mock.exception.assert_called_once()
    db_mock.session.rollback.assert_called_once()
|
||||
718
api/tests/unit_tests/services/test_website_service.py
Normal file
718
api/tests/unit_tests/services/test_website_service.py
Normal file
@ -0,0 +1,718 @@
|
||||
"""Unit tests for services.website_service.
|
||||
|
||||
Focuses on provider dispatching, argument validation, and provider-specific branches
|
||||
without making any real network/storage/redis calls.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
import services.website_service as website_service_module
|
||||
from services.website_service import (
|
||||
CrawlOptions,
|
||||
WebsiteCrawlApiRequest,
|
||||
WebsiteCrawlStatusApiRequest,
|
||||
WebsiteService,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _DummyHttpxResponse:
|
||||
payload: dict[str, Any]
|
||||
|
||||
def json(self) -> dict[str, Any]:
|
||||
return self.payload
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def stub_current_user(monkeypatch: pytest.MonkeyPatch) -> None:
    """Autouse fixture: replace the module's current_user with a stub carrying a fixed tenant id."""
    monkeypatch.setattr(
        website_service_module,
        "current_user",
        # Anonymous throwaway class instance; only current_tenant_id is read.
        type("User", (), {"current_tenant_id": "tenant-1"})(),
    )
|
||||
|
||||
|
||||
def test_crawl_options_include_exclude_paths() -> None:
    """Comma-separated include/exclude strings split into lists; None yields empty lists."""
    populated = CrawlOptions(includes="a,b", excludes="x,y")
    assert populated.get_include_paths() == ["a", "b"]
    assert populated.get_exclude_paths() == ["x", "y"]

    blank = CrawlOptions(includes=None, excludes=None)
    assert blank.get_include_paths() == []
    assert blank.get_exclude_paths() == []
|
||||
|
||||
|
||||
def test_website_crawl_api_request_from_args_valid_and_to_crawl_request() -> None:
    """from_args parses a fully-populated options dict and to_crawl_request preserves every field."""
    args = {
        "provider": "firecrawl",
        "url": "https://example.com",
        "options": {
            "limit": 2,
            "crawl_sub_pages": True,
            "only_main_content": True,
            "includes": "a,b",
            "excludes": "x",
            "prompt": "hi",
            "max_depth": 3,
            "use_sitemap": False,
        },
    }

    api_req = WebsiteCrawlApiRequest.from_args(args)
    crawl_req = api_req.to_crawl_request()

    assert crawl_req.provider == "firecrawl"
    assert crawl_req.url == "https://example.com"
    assert crawl_req.options.limit == 2
    assert crawl_req.options.crawl_sub_pages is True
    assert crawl_req.options.only_main_content is True
    assert crawl_req.options.get_include_paths() == ["a", "b"]
    assert crawl_req.options.get_exclude_paths() == ["x"]
    assert crawl_req.options.prompt == "hi"
    assert crawl_req.options.max_depth == 3
    assert crawl_req.options.use_sitemap is False
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("args", "missing_msg"),
    [
        ({}, "Provider is required"),
        ({"provider": "firecrawl"}, "URL is required"),
        ({"provider": "firecrawl", "url": "https://example.com"}, "Options are required"),
    ],
)
def test_website_crawl_api_request_from_args_requires_fields(args: dict, missing_msg: str) -> None:
    """Each mandatory field missing from args raises a ValueError naming that field."""
    with pytest.raises(ValueError, match=missing_msg):
        WebsiteCrawlApiRequest.from_args(args)
|
||||
|
||||
|
||||
def test_website_crawl_status_api_request_from_args_requires_fields() -> None:
    """Status requests demand both a provider and a non-empty job id; valid input round-trips."""
    with pytest.raises(ValueError, match="Provider is required"):
        WebsiteCrawlStatusApiRequest.from_args({}, job_id="job-1")

    with pytest.raises(ValueError, match="Job ID is required"):
        WebsiteCrawlStatusApiRequest.from_args({"provider": "firecrawl"}, job_id="")

    req = WebsiteCrawlStatusApiRequest.from_args({"provider": "firecrawl"}, job_id="job-1")
    assert req.provider == "firecrawl"
    assert req.job_id == "job-1"
|
||||
|
||||
|
||||
def test_get_credentials_and_config_selects_plugin_id_and_key_firecrawl(monkeypatch: pytest.MonkeyPatch) -> None:
    """Firecrawl resolves to its datasource plugin id and reads the provider-specific key field."""
    service_instance = MagicMock(name="DatasourceProviderService-instance")
    service_instance.get_datasource_credentials.return_value = {"firecrawl_api_key": "k", "base_url": "b"}
    monkeypatch.setattr(website_service_module, "DatasourceProviderService", MagicMock(return_value=service_instance))

    api_key, config = WebsiteService._get_credentials_and_config("tenant-1", "firecrawl")
    assert api_key == "k"
    assert config["base_url"] == "b"

    service_instance.get_datasource_credentials.assert_called_once_with(
        tenant_id="tenant-1",
        provider="firecrawl",
        plugin_id="langgenius/firecrawl_datasource",
    )
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("provider", "plugin_id"),
    [
        ("watercrawl", "langgenius/watercrawl_datasource"),
        ("jinareader", "langgenius/jina_datasource"),
    ],
)
def test_get_credentials_and_config_selects_plugin_id_and_key_api_key(
    monkeypatch: pytest.MonkeyPatch, provider: str, plugin_id: str
) -> None:
    """Watercrawl/jinareader map to their plugin ids and read the generic api_key field."""
    service_instance = MagicMock(name="DatasourceProviderService-instance")
    service_instance.get_datasource_credentials.return_value = {"api_key": "enc-key", "base_url": "b"}
    monkeypatch.setattr(website_service_module, "DatasourceProviderService", MagicMock(return_value=service_instance))

    api_key, config = WebsiteService._get_credentials_and_config("tenant-1", provider)
    assert api_key == "enc-key"
    assert config["base_url"] == "b"

    service_instance.get_datasource_credentials.assert_called_once_with(
        tenant_id="tenant-1",
        provider=provider,
        plugin_id=plugin_id,
    )
|
||||
|
||||
|
||||
def test_get_credentials_and_config_rejects_invalid_provider() -> None:
    """An unknown provider name is rejected before any credential lookup happens."""
    with pytest.raises(ValueError, match="Invalid provider"):
        WebsiteService._get_credentials_and_config("tenant-1", "unknown")
|
||||
|
||||
|
||||
def test_get_credentials_and_config_hits_unreachable_guard_branch(monkeypatch: pytest.MonkeyPatch) -> None:
    """Cover the defensive re-check: a provider object that compares equal to "firecrawl" only once
    passes the first dispatch check but fails the second, hitting the otherwise-unreachable guard."""

    class FlakyProvider:
        # Equality with "firecrawl" succeeds exactly once, then always fails.
        def __init__(self) -> None:
            self._eq_calls = 0

        def __hash__(self) -> int:
            return 1

        def __eq__(self, other: object) -> bool:
            if other == "firecrawl":
                self._eq_calls += 1
                return self._eq_calls == 1
            return False

        def __repr__(self) -> str:
            return "FlakyProvider()"

    service_instance = MagicMock(name="DatasourceProviderService-instance")
    service_instance.get_datasource_credentials.return_value = {"firecrawl_api_key": "k"}
    monkeypatch.setattr(website_service_module, "DatasourceProviderService", MagicMock(return_value=service_instance))

    with pytest.raises(ValueError, match="Invalid provider"):
        WebsiteService._get_credentials_and_config("tenant-1", FlakyProvider())  # type: ignore[arg-type]
|
||||
|
||||
|
||||
def test_get_decrypted_api_key_requires_api_key(monkeypatch: pytest.MonkeyPatch) -> None:
    """A config without api_key raises a descriptive ValueError before any decryption."""
    monkeypatch.setattr(website_service_module.encrypter, "decrypt_token", MagicMock())
    with pytest.raises(ValueError, match="API key not found in configuration"):
        WebsiteService._get_decrypted_api_key("tenant-1", {})
|
||||
|
||||
|
||||
def test_get_decrypted_api_key_decrypts(monkeypatch: pytest.MonkeyPatch) -> None:
    """A stored encrypted token is returned decrypted via encrypter.decrypt_token."""
    fake_decrypt = MagicMock(return_value="plain")
    monkeypatch.setattr(website_service_module.encrypter, "decrypt_token", fake_decrypt)

    result = WebsiteService._get_decrypted_api_key("tenant-1", {"api_key": "enc"})

    assert result == "plain"
    fake_decrypt.assert_called_once_with(tenant_id="tenant-1", token="enc")
|
||||
|
||||
|
||||
def test_document_create_args_validate_wraps_error_message() -> None:
    """Validation errors are re-raised with an 'Invalid arguments:' prefix around the original message."""
    with pytest.raises(ValueError, match=r"^Invalid arguments: Provider is required$"):
        WebsiteService.document_create_args_validate({})
|
||||
|
||||
|
||||
def test_crawl_url_dispatches_by_provider(monkeypatch: pytest.MonkeyPatch) -> None:
    """crawl_url routes a firecrawl request to _crawl_with_firecrawl, passing the parsed request."""
    api_request = WebsiteCrawlApiRequest(provider="firecrawl", url="https://example.com", options={"limit": 1})
    crawl_request = api_request.to_crawl_request()

    monkeypatch.setattr(WebsiteService, "_get_credentials_and_config", MagicMock(return_value=("k", {"base_url": "b"})))
    firecrawl_mock = MagicMock(return_value={"status": "active", "job_id": "j1"})
    monkeypatch.setattr(WebsiteService, "_crawl_with_firecrawl", firecrawl_mock)

    result = WebsiteService.crawl_url(api_request)

    assert result == {"status": "active", "job_id": "j1"}
    firecrawl_mock.assert_called_once()
    # The request kwarg must equal the independently-parsed crawl request.
    assert firecrawl_mock.call_args.kwargs["request"] == crawl_request
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("provider", "method_name"),
    [
        ("watercrawl", "_crawl_with_watercrawl"),
        ("jinareader", "_crawl_with_jinareader"),
    ],
)
def test_crawl_url_dispatches_other_providers(monkeypatch: pytest.MonkeyPatch, provider: str, method_name: str) -> None:
    """Each non-firecrawl provider dispatches to its matching private crawl method."""
    api_request = WebsiteCrawlApiRequest(provider=provider, url="https://example.com", options={"limit": 1})
    monkeypatch.setattr(WebsiteService, "_get_credentials_and_config", MagicMock(return_value=("k", {"base_url": "b"})))

    impl_mock = MagicMock(return_value={"status": "active"})
    monkeypatch.setattr(WebsiteService, method_name, impl_mock)

    assert WebsiteService.crawl_url(api_request) == {"status": "active"}
    impl_mock.assert_called_once()
|
||||
|
||||
|
||||
def test_crawl_url_rejects_invalid_provider(monkeypatch: pytest.MonkeyPatch) -> None:
    """Dispatch fails for unknown providers even when the credential lookup succeeds."""
    api_request = WebsiteCrawlApiRequest(provider="bad", url="https://example.com", options={"limit": 1})
    monkeypatch.setattr(WebsiteService, "_get_credentials_and_config", MagicMock(return_value=("k", {})))

    with pytest.raises(ValueError, match="Invalid provider"):
        WebsiteService.crawl_url(api_request)
|
||||
|
||||
|
||||
def test_crawl_with_firecrawl_builds_params_single_page_and_sets_redis(monkeypatch: pytest.MonkeyPatch) -> None:
    """Single-page crawls force limit=1 regardless of the requested limit, pass scrape options,
    and cache the crawl start timestamp in redis under website_crawl_<job_id> for one hour."""
    firecrawl_instance = MagicMock(name="FirecrawlApp-instance")
    firecrawl_instance.crawl_url.return_value = "job-1"
    firecrawl_cls = MagicMock(return_value=firecrawl_instance)
    monkeypatch.setattr(website_service_module, "FirecrawlApp", firecrawl_cls)

    redis_mock = MagicMock()
    monkeypatch.setattr(website_service_module, "redis_client", redis_mock)

    # Freeze "now" so the redis value can be asserted exactly.
    fixed_now = datetime(2024, 1, 1, tzinfo=UTC)
    with patch.object(website_service_module.datetime, "datetime") as datetime_mock:
        datetime_mock.now.return_value = fixed_now

        req = WebsiteCrawlApiRequest(
            provider="firecrawl", url="https://example.com", options={"limit": 5}
        ).to_crawl_request()
        req.options.crawl_sub_pages = False
        req.options.only_main_content = True

        result = WebsiteService._crawl_with_firecrawl(request=req, api_key="k", config={"base_url": "b"})

    assert result == {"status": "active", "job_id": "job-1"}

    firecrawl_cls.assert_called_once_with(api_key="k", base_url="b")
    firecrawl_instance.crawl_url.assert_called_once()
    _, params = firecrawl_instance.crawl_url.call_args.args
    # crawl_sub_pages=False overrides the requested limit of 5 down to a single page.
    assert params["limit"] == 1
    assert params["includePaths"] == []
    assert params["excludePaths"] == []
    assert params["scrapeOptions"] == {"onlyMainContent": True}

    redis_mock.setex.assert_called_once()
    key, ttl, value = redis_mock.setex.call_args.args
    assert key == "website_crawl_job-1"
    assert ttl == 3600
    assert float(value) == pytest.approx(fixed_now.timestamp(), rel=0, abs=1e-6)
|
||||
|
||||
|
||||
def test_crawl_with_firecrawl_builds_params_multi_page_including_prompt(monkeypatch: pytest.MonkeyPatch) -> None:
    """Multi-page crawls keep the caller's limit and include/exclude paths and forward the prompt."""
    firecrawl_instance = MagicMock(name="FirecrawlApp-instance")
    firecrawl_instance.crawl_url.return_value = "job-2"
    monkeypatch.setattr(website_service_module, "FirecrawlApp", MagicMock(return_value=firecrawl_instance))
    monkeypatch.setattr(website_service_module, "redis_client", MagicMock())

    req = WebsiteCrawlApiRequest(
        provider="firecrawl",
        url="https://example.com",
        options={
            "crawl_sub_pages": True,
            "limit": 3,
            "only_main_content": False,
            "includes": "a,b",
            "excludes": "x",
            "prompt": "use this",
        },
    ).to_crawl_request()

    WebsiteService._crawl_with_firecrawl(request=req, api_key="k", config={"base_url": None})
    _, params = firecrawl_instance.crawl_url.call_args.args
    assert params["includePaths"] == ["a", "b"]
    assert params["excludePaths"] == ["x"]
    assert params["limit"] == 3
    assert params["scrapeOptions"] == {"onlyMainContent": False}
    assert params["prompt"] == "use this"
|
||||
|
||||
|
||||
def test_crawl_with_watercrawl_passes_options_dict(monkeypatch: pytest.MonkeyPatch) -> None:
    """WaterCrawlProvider receives the api key/base url and the options as a plain dict, unchanged."""
    provider_instance = MagicMock()
    provider_instance.crawl_url.return_value = {"status": "active", "job_id": "w1"}
    provider_cls = MagicMock(return_value=provider_instance)
    monkeypatch.setattr(website_service_module, "WaterCrawlProvider", provider_cls)

    req = WebsiteCrawlApiRequest(
        provider="watercrawl",
        url="https://example.com",
        options={
            "limit": 2,
            "crawl_sub_pages": True,
            "only_main_content": True,
            "includes": "a",
            "excludes": None,
            "max_depth": 5,
            "use_sitemap": False,
        },
    ).to_crawl_request()

    result = WebsiteService._crawl_with_watercrawl(request=req, api_key="k", config={"base_url": "b"})
    assert result == {"status": "active", "job_id": "w1"}

    provider_cls.assert_called_once_with(api_key="k", base_url="b")
    # The options dict must round-trip field-for-field into the provider call.
    provider_instance.crawl_url.assert_called_once_with(
        url="https://example.com",
        options={
            "limit": 2,
            "crawl_sub_pages": True,
            "only_main_content": True,
            "includes": "a",
            "excludes": None,
            "max_depth": 5,
            "use_sitemap": False,
        },
    )
|
||||
|
||||
|
||||
def test_crawl_with_jinareader_single_page_success(monkeypatch: pytest.MonkeyPatch) -> None:
    """Single-page jinareader crawls do one HTTP GET and return the payload data directly."""
    get_mock = MagicMock(return_value=_DummyHttpxResponse({"code": 200, "data": {"title": "t"}}))
    monkeypatch.setattr(website_service_module.httpx, "get", get_mock)

    req = WebsiteCrawlApiRequest(
        provider="jinareader", url="https://example.com", options={"crawl_sub_pages": False}
    ).to_crawl_request()
    req.options.crawl_sub_pages = False

    result = WebsiteService._crawl_with_jinareader(request=req, api_key="k")
    assert result == {"status": "active", "data": {"title": "t"}}
    get_mock.assert_called_once()
|
||||
|
||||
|
||||
def test_crawl_with_jinareader_single_page_failure(monkeypatch: pytest.MonkeyPatch) -> None:
    """A non-200 code in the single-page response raises a 'Failed to crawl:' ValueError."""
    monkeypatch.setattr(website_service_module.httpx, "get", MagicMock(return_value=_DummyHttpxResponse({"code": 500})))
    req = WebsiteCrawlApiRequest(
        provider="jinareader", url="https://example.com", options={"crawl_sub_pages": False}
    ).to_crawl_request()
    req.options.crawl_sub_pages = False

    with pytest.raises(ValueError, match="Failed to crawl:"):
        WebsiteService._crawl_with_jinareader(request=req, api_key="k")
|
||||
|
||||
|
||||
def test_crawl_with_jinareader_multi_page_success(monkeypatch: pytest.MonkeyPatch) -> None:
    """Multi-page jinareader crawls POST once and surface the returned taskId as the job id."""
    post_mock = MagicMock(return_value=_DummyHttpxResponse({"code": 200, "data": {"taskId": "t1"}}))
    monkeypatch.setattr(website_service_module.httpx, "post", post_mock)

    req = WebsiteCrawlApiRequest(
        provider="jinareader",
        url="https://example.com",
        options={"crawl_sub_pages": True, "limit": 5, "use_sitemap": True},
    ).to_crawl_request()
    req.options.crawl_sub_pages = True

    result = WebsiteService._crawl_with_jinareader(request=req, api_key="k")
    assert result == {"status": "active", "job_id": "t1"}
    post_mock.assert_called_once()
|
||||
|
||||
|
||||
def test_crawl_with_jinareader_multi_page_failure(monkeypatch: pytest.MonkeyPatch) -> None:
    """A non-200 response when starting a multi-page crawl raises a ValueError."""
    failing_post = MagicMock(return_value=_DummyHttpxResponse({"code": 400}))
    monkeypatch.setattr(website_service_module.httpx, "post", failing_post)

    request = WebsiteCrawlApiRequest(
        provider="jinareader",
        url="https://example.com",
        options={"crawl_sub_pages": True, "limit": 2, "use_sitemap": False},
    ).to_crawl_request()
    request.options.crawl_sub_pages = True

    with pytest.raises(ValueError, match="Failed to crawl$"):
        WebsiteService._crawl_with_jinareader(request=request, api_key="k")
def test_get_crawl_status_dispatches(monkeypatch: pytest.MonkeyPatch) -> None:
    """get_crawl_status routes each provider to its dedicated status helper."""
    monkeypatch.setattr(
        WebsiteService, "_get_credentials_and_config", MagicMock(return_value=("k", {"base_url": "b"}))
    )

    firecrawl_mock = MagicMock(return_value={"status": "active"})
    monkeypatch.setattr(WebsiteService, "_get_firecrawl_status", firecrawl_mock)
    assert WebsiteService.get_crawl_status("job-1", "firecrawl") == {"status": "active"}
    firecrawl_mock.assert_called_once_with("job-1", "k", {"base_url": "b"})

    watercrawl_mock = MagicMock(return_value={"status": "active", "job_id": "w"})
    monkeypatch.setattr(WebsiteService, "_get_watercrawl_status", watercrawl_mock)
    assert WebsiteService.get_crawl_status("job-2", "watercrawl") == {"status": "active", "job_id": "w"}
    watercrawl_mock.assert_called_once_with("job-2", "k", {"base_url": "b"})

    jinareader_mock = MagicMock(return_value={"status": "active", "job_id": "j"})
    monkeypatch.setattr(WebsiteService, "_get_jinareader_status", jinareader_mock)
    assert WebsiteService.get_crawl_status("job-3", "jinareader") == {"status": "active", "job_id": "j"}
    # Jina Reader's status helper takes no config argument, unlike the other two.
    jinareader_mock.assert_called_once_with("job-3", "k")
def test_get_crawl_status_typed_rejects_invalid_provider(monkeypatch: pytest.MonkeyPatch) -> None:
    """An unknown provider in the typed status request raises a ValueError."""
    monkeypatch.setattr(WebsiteService, "_get_credentials_and_config", MagicMock(return_value=("k", {})))
    bad_request = WebsiteCrawlStatusApiRequest(provider="bad", job_id="j")
    with pytest.raises(ValueError, match="Invalid provider"):
        WebsiteService.get_crawl_status_typed(bad_request)
def test_get_firecrawl_status_adds_time_consuming_when_completed_and_cached(monkeypatch: pytest.MonkeyPatch) -> None:
    """A completed Firecrawl job with a cached start time gains a time_consuming field."""
    app_mock = MagicMock()
    app_mock.check_crawl_status.return_value = {"status": "completed", "total": 2, "current": 2, "data": []}
    monkeypatch.setattr(website_service_module, "FirecrawlApp", MagicMock(return_value=app_mock))

    fake_redis = MagicMock()
    fake_redis.get.return_value = b"100.0"  # cached crawl start timestamp
    monkeypatch.setattr(website_service_module, "redis_client", fake_redis)

    # Freeze "now" at t=105 so the elapsed time is exactly 5 seconds.
    with patch.object(website_service_module.datetime, "datetime") as frozen_datetime:
        frozen_datetime.now.return_value = datetime.fromtimestamp(105.0, tz=UTC)
        status = WebsiteService._get_firecrawl_status(job_id="job-1", api_key="k", config={"base_url": "b"})

    assert status["status"] == "completed"
    assert status["time_consuming"] == "5.00"
    # The cached start time is removed once the job completes.
    fake_redis.delete.assert_called_once_with("website_crawl_job-1")
def test_get_firecrawl_status_completed_without_cache_does_not_add_time(monkeypatch: pytest.MonkeyPatch) -> None:
    """Without a cached start time, a completed job reports no time_consuming."""
    app_mock = MagicMock()
    app_mock.check_crawl_status.return_value = {"status": "completed"}
    monkeypatch.setattr(website_service_module, "FirecrawlApp", MagicMock(return_value=app_mock))

    fake_redis = MagicMock()
    fake_redis.get.return_value = None
    monkeypatch.setattr(website_service_module, "redis_client", fake_redis)

    status = WebsiteService._get_firecrawl_status(job_id="job-1", api_key="k", config={"base_url": None})

    assert status["status"] == "completed"
    assert "time_consuming" not in status
    fake_redis.delete.assert_not_called()
def test_get_watercrawl_status_delegates(monkeypatch: pytest.MonkeyPatch) -> None:
    """The WaterCrawl status helper forwards straight to the provider client."""
    client = MagicMock()
    client.get_crawl_status.return_value = {"status": "active", "job_id": "w1"}
    monkeypatch.setattr(website_service_module, "WaterCrawlProvider", MagicMock(return_value=client))

    status = WebsiteService._get_watercrawl_status("job-1", "k", {"base_url": "b"})

    assert status == {"status": "active", "job_id": "w1"}
    client.get_crawl_status.assert_called_once_with("job-1")
def test_get_jinareader_status_active(monkeypatch: pytest.MonkeyPatch) -> None:
    """An in-progress Jina Reader job reports totals, progress and duration."""
    payload = {
        "data": {
            "status": "active",
            "urls": ["a", "b"],
            "processed": {"a": {}},
            "failed": {"b": {}},
            "duration": 3000,
        }
    }
    http_post = MagicMock(return_value=_DummyHttpxResponse(payload))
    monkeypatch.setattr(website_service_module.httpx, "post", http_post)

    status = WebsiteService._get_jinareader_status("job-1", "k")

    assert status["status"] == "active"
    assert status["total"] == 2  # len(urls)
    assert status["current"] == 2  # processed + failed
    assert status["time_consuming"] == 3.0  # duration in ms -> seconds
    assert status["data"] == []  # no page data while still active
    http_post.assert_called_once()
def test_get_jinareader_status_completed_formats_processed_items(monkeypatch: pytest.MonkeyPatch) -> None:
    """A completed job triggers a second request and reshapes processed pages."""
    status_payload = {
        "data": {
            "status": "completed",
            "urls": ["u1"],
            "processed": {"u1": {}},
            "failed": {},
            "duration": 1000,
        }
    }
    processed_payload = {
        "data": {
            "processed": {
                "u1": {
                    "data": {
                        "title": "t",
                        "url": "u1",
                        "description": "d",
                        "content": "md",
                    }
                }
            }
        }
    }
    # First POST returns the job status; second fetches the processed pages.
    http_post = MagicMock(side_effect=[_DummyHttpxResponse(status_payload), _DummyHttpxResponse(processed_payload)])
    monkeypatch.setattr(website_service_module.httpx, "post", http_post)

    status = WebsiteService._get_jinareader_status("job-1", "k")

    assert status["status"] == "completed"
    assert status["data"] == [{"title": "t", "source_url": "u1", "description": "d", "markdown": "md"}]
    assert http_post.call_count == 2
def test_get_crawl_url_data_dispatches_invalid_provider() -> None:
    """An unrecognised provider string is rejected with a ValueError."""
    with pytest.raises(ValueError, match="Invalid provider"):
        WebsiteService.get_crawl_url_data("job-1", "bad", "https://example.com", "tenant-1")
def test_get_crawl_url_data_hits_invalid_provider_branch_when_credentials_stubbed(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Even with credentials stubbed out, a non-string provider is rejected."""
    monkeypatch.setattr(WebsiteService, "_get_credentials_and_config", MagicMock(return_value=("k", {})))
    with pytest.raises(ValueError, match="Invalid provider"):
        WebsiteService.get_crawl_url_data("job-1", object(), "u", "tenant-1")  # type: ignore[arg-type]
@pytest.mark.parametrize(
    ("provider", "method_name"),
    [
        ("firecrawl", "_get_firecrawl_url_data"),
        ("watercrawl", "_get_watercrawl_url_data"),
        ("jinareader", "_get_jinareader_url_data"),
    ],
)
def test_get_crawl_url_data_dispatches(monkeypatch: pytest.MonkeyPatch, provider: str, method_name: str) -> None:
    """Each known provider is routed to its matching URL-data helper."""
    monkeypatch.setattr(
        WebsiteService, "_get_credentials_and_config", MagicMock(return_value=("k", {"base_url": "b"}))
    )
    helper = MagicMock(return_value={"ok": True})
    monkeypatch.setattr(WebsiteService, method_name, helper)

    assert WebsiteService.get_crawl_url_data("job-1", provider, "u", "tenant-1") == {"ok": True}
    helper.assert_called_once()
def test_get_firecrawl_url_data_reads_from_storage_when_present(monkeypatch: pytest.MonkeyPatch) -> None:
    """Cached crawl results in storage are preferred and copied, not aliased."""
    cached_items = [{"source_url": "https://example.com", "title": "t"}]

    fake_storage = MagicMock()
    fake_storage.exists.return_value = True
    fake_storage.load_once.return_value = json.dumps(cached_items).encode("utf-8")
    monkeypatch.setattr(website_service_module, "storage", fake_storage)

    # FirecrawlApp must not matter when the cache hit path is taken.
    monkeypatch.setattr(website_service_module, "FirecrawlApp", MagicMock())

    found = WebsiteService._get_firecrawl_url_data("job-1", "https://example.com", "k", {"base_url": "b"})

    assert found == {"source_url": "https://example.com", "title": "t"}
    # The result comes from a JSON round-trip, so it is a fresh object.
    assert found is not cached_items[0]
def test_get_firecrawl_url_data_returns_none_when_storage_empty(monkeypatch: pytest.MonkeyPatch) -> None:
    """An existing-but-empty storage object yields None rather than an error."""
    fake_storage = MagicMock()
    fake_storage.exists.return_value = True
    fake_storage.load_once.return_value = b""
    monkeypatch.setattr(website_service_module, "storage", fake_storage)

    assert WebsiteService._get_firecrawl_url_data("job-1", "https://example.com", "k", {}) is None
def test_get_firecrawl_url_data_raises_when_job_not_completed(monkeypatch: pytest.MonkeyPatch) -> None:
    """With no cache and an unfinished crawl, the helper raises a ValueError."""
    fake_storage = MagicMock()
    fake_storage.exists.return_value = False
    monkeypatch.setattr(website_service_module, "storage", fake_storage)

    app_mock = MagicMock()
    app_mock.check_crawl_status.return_value = {"status": "active"}
    monkeypatch.setattr(website_service_module, "FirecrawlApp", MagicMock(return_value=app_mock))

    with pytest.raises(ValueError, match="Crawl job is not completed"):
        WebsiteService._get_firecrawl_url_data("job-1", "https://example.com", "k", {"base_url": None})
def test_get_firecrawl_url_data_returns_none_when_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
    """A completed crawl with no entry matching the requested URL returns None."""
    fake_storage = MagicMock()
    fake_storage.exists.return_value = False
    monkeypatch.setattr(website_service_module, "storage", fake_storage)

    app_mock = MagicMock()
    app_mock.check_crawl_status.return_value = {"status": "completed", "data": [{"source_url": "x"}]}
    monkeypatch.setattr(website_service_module, "FirecrawlApp", MagicMock(return_value=app_mock))

    assert WebsiteService._get_firecrawl_url_data("job-1", "https://example.com", "k", {"base_url": "b"}) is None
def test_get_watercrawl_url_data_delegates(monkeypatch: pytest.MonkeyPatch) -> None:
    """The WaterCrawl URL-data helper forwards to the provider client."""
    client = MagicMock()
    client.get_crawl_url_data.return_value = {"source_url": "u"}
    monkeypatch.setattr(website_service_module, "WaterCrawlProvider", MagicMock(return_value=client))

    assert WebsiteService._get_watercrawl_url_data("job-1", "u", "k", {"base_url": "b"}) == {"source_url": "u"}
    client.get_crawl_url_data.assert_called_once_with("job-1", "u")
def test_get_jinareader_url_data_without_job_id_success(monkeypatch: pytest.MonkeyPatch) -> None:
    """With no job id, the helper scrapes the URL directly and returns its data."""
    http_get = MagicMock(return_value=_DummyHttpxResponse({"code": 200, "data": {"url": "u"}}))
    monkeypatch.setattr(website_service_module.httpx, "get", http_get)

    assert WebsiteService._get_jinareader_url_data("", "u", "k") == {"url": "u"}
def test_get_jinareader_url_data_without_job_id_failure(monkeypatch: pytest.MonkeyPatch) -> None:
    """A non-200 direct scrape response raises a ValueError."""
    failing_get = MagicMock(return_value=_DummyHttpxResponse({"code": 500}))
    monkeypatch.setattr(website_service_module.httpx, "get", failing_get)

    with pytest.raises(ValueError, match="Failed to crawl$"):
        WebsiteService._get_jinareader_url_data("", "u", "k")
def test_get_jinareader_url_data_with_job_id_completed_returns_matching_item(monkeypatch: pytest.MonkeyPatch) -> None:
    """A completed job returns the processed entry whose URL matches."""
    status_payload = {"data": {"status": "completed", "processed": {"u1": {}}}}
    processed_payload = {"data": {"processed": {"u1": {"data": {"url": "u", "title": "t"}}}}}

    # First POST checks status; second retrieves the processed pages.
    http_post = MagicMock(side_effect=[_DummyHttpxResponse(status_payload), _DummyHttpxResponse(processed_payload)])
    monkeypatch.setattr(website_service_module.httpx, "post", http_post)

    assert WebsiteService._get_jinareader_url_data("job-1", "u", "k") == {"url": "u", "title": "t"}
    assert http_post.call_count == 2
def test_get_jinareader_url_data_with_job_id_not_completed_raises(monkeypatch: pytest.MonkeyPatch) -> None:
    """Requesting URL data for an unfinished Jina Reader job raises a ValueError."""
    http_post = MagicMock(return_value=_DummyHttpxResponse({"data": {"status": "active"}}))
    monkeypatch.setattr(website_service_module.httpx, "post", http_post)

    # Fixed: the pattern previously read r"Crawl job is no\s*t completed" — a
    # garbled "not". Match the literal message, consistent with the Firecrawl
    # counterpart test (test_get_firecrawl_url_data_raises_when_job_not_completed).
    with pytest.raises(ValueError, match="Crawl job is not completed"):
        WebsiteService._get_jinareader_url_data("job-1", "u", "k")
def test_get_jinareader_url_data_with_job_id_completed_but_not_found_returns_none(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A completed job with no processed entry matching the URL returns None."""
    status_payload = {"data": {"status": "completed", "processed": {"u1": {}}}}
    processed_payload = {"data": {"processed": {"u1": {"data": {"url": "other"}}}}}

    http_post = MagicMock(side_effect=[_DummyHttpxResponse(status_payload), _DummyHttpxResponse(processed_payload)])
    monkeypatch.setattr(website_service_module.httpx, "post", http_post)

    assert WebsiteService._get_jinareader_url_data("job-1", "u", "k") is None
def test_get_scrape_url_data_dispatches_and_rejects_invalid_provider(monkeypatch: pytest.MonkeyPatch) -> None:
    """Scrape requests route to firecrawl/watercrawl; jinareader is unsupported."""
    monkeypatch.setattr(
        WebsiteService, "_get_credentials_and_config", MagicMock(return_value=("k", {"base_url": "b"}))
    )

    firecrawl_scrape = MagicMock(return_value={"data": "x"})
    monkeypatch.setattr(WebsiteService, "_scrape_with_firecrawl", firecrawl_scrape)
    assert WebsiteService.get_scrape_url_data("firecrawl", "u", "tenant-1", True) == {"data": "x"}
    firecrawl_scrape.assert_called_once()

    watercrawl_scrape = MagicMock(return_value={"data": "y"})
    monkeypatch.setattr(WebsiteService, "_scrape_with_watercrawl", watercrawl_scrape)
    assert WebsiteService.get_scrape_url_data("watercrawl", "u", "tenant-1", False) == {"data": "y"}
    watercrawl_scrape.assert_called_once()

    with pytest.raises(ValueError, match="Invalid provider"):
        WebsiteService.get_scrape_url_data("jinareader", "u", "tenant-1", True)
def test_scrape_with_firecrawl_calls_app(monkeypatch: pytest.MonkeyPatch) -> None:
    """The Firecrawl scrape helper passes the URL and main-content flag through."""
    app_mock = MagicMock()
    app_mock.scrape_url.return_value = {"markdown": "m"}
    monkeypatch.setattr(website_service_module, "FirecrawlApp", MagicMock(return_value=app_mock))

    scrape_request = website_service_module.ScrapeRequest(
        provider="firecrawl",
        url="u",
        tenant_id="tenant-1",
        only_main_content=True,
    )
    result = WebsiteService._scrape_with_firecrawl(request=scrape_request, api_key="k", config={"base_url": "b"})

    assert result == {"markdown": "m"}
    app_mock.scrape_url.assert_called_once_with(url="u", params={"onlyMainContent": True})
def test_scrape_with_watercrawl_calls_provider(monkeypatch: pytest.MonkeyPatch) -> None:
    """The WaterCrawl scrape helper delegates to the provider client."""
    client = MagicMock()
    client.scrape_url.return_value = {"markdown": "m"}
    monkeypatch.setattr(website_service_module, "WaterCrawlProvider", MagicMock(return_value=client))

    scrape_request = website_service_module.ScrapeRequest(
        provider="watercrawl",
        url="u",
        tenant_id="tenant-1",
        only_main_content=False,
    )
    result = WebsiteService._scrape_with_watercrawl(request=scrape_request, api_key="k", config={"base_url": "b"})

    assert result == {"markdown": "m"}
    client.scrape_url.assert_called_once_with("u")
@ -311,7 +311,9 @@ class TestWorkflowService:
|
||||
mock_workflow.conversation_variables = []
|
||||
|
||||
# Mock node config
|
||||
mock_workflow.get_node_config_by_id.return_value = {"id": "node-1", "data": {"type": "llm"}}
|
||||
mock_workflow.get_node_config_by_id.return_value = NodeConfigDictAdapter.validate_python(
|
||||
{"id": "node-1", "data": {"type": NodeType.LLM.value}}
|
||||
)
|
||||
mock_workflow.get_enclosing_node_type_and_id.return_value = None
|
||||
|
||||
# Mock class methods
|
||||
@ -376,7 +378,9 @@ class TestWorkflowService:
|
||||
mock_workflow.tenant_id = "tenant-1"
|
||||
mock_workflow.environment_variables = []
|
||||
mock_workflow.conversation_variables = []
|
||||
mock_workflow.get_node_config_by_id.return_value = {"id": "node-1", "data": {"type": "llm"}}
|
||||
mock_workflow.get_node_config_by_id.return_value = NodeConfigDictAdapter.validate_python(
|
||||
{"id": "node-1", "data": {"type": NodeType.LLM.value}}
|
||||
)
|
||||
mock_workflow.get_enclosing_node_type_and_id.return_value = None
|
||||
|
||||
monkeypatch.setattr(workflow_service_module, "WorkflowDraftVariableService", MagicMock())
|
||||
|
||||
@ -7,17 +7,15 @@ const { mockReactMarkdownWrapper } = vi.hoisted(() => ({
|
||||
mockReactMarkdownWrapper: vi.fn(),
|
||||
}))
|
||||
|
||||
vi.mock('../react-markdown-wrapper', () => ({
|
||||
ReactMarkdownWrapper: () => null,
|
||||
}))
|
||||
|
||||
vi.mock('next/dynamic', () => ({
|
||||
default: (loader: () => Promise<unknown>) => {
|
||||
void loader()
|
||||
return (props: { latexContent: string }) => {
|
||||
default: () => {
|
||||
const MockStreamdownWrapper = (props: { latexContent: string }) => {
|
||||
mockReactMarkdownWrapper(props)
|
||||
return <div data-testid="react-markdown-wrapper">{props.latexContent}</div>
|
||||
}
|
||||
|
||||
MockStreamdownWrapper.displayName = 'MockStreamdownWrapper'
|
||||
return MockStreamdownWrapper
|
||||
},
|
||||
}))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user