mirror of
https://github.com/langgenius/dify.git
synced 2026-03-16 20:37:42 +08:00
test: added test cases for core.workflow module (#33126)
This commit is contained in:
@ -150,8 +150,9 @@ class TestAdvancedChatAppGeneratorInternals:
|
||||
"_DummyTraceQueueManager",
|
||||
(TraceQueueManager,),
|
||||
{
|
||||
"__init__": lambda self, app_id=None, user_id=None: setattr(self, "app_id", app_id)
|
||||
or setattr(self, "user_id", user_id)
|
||||
"__init__": lambda self, app_id=None, user_id=None: (
|
||||
setattr(self, "app_id", app_id) or setattr(self, "user_id", user_id)
|
||||
)
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.TraceQueueManager", DummyTraceQueueManager)
|
||||
@ -1124,8 +1125,9 @@ class TestAdvancedChatAppGeneratorInternals:
|
||||
"_DummyTraceQueueManager",
|
||||
(TraceQueueManager,),
|
||||
{
|
||||
"__init__": lambda self, app_id=None, user_id=None: setattr(self, "app_id", app_id)
|
||||
or setattr(self, "user_id", user_id)
|
||||
"__init__": lambda self, app_id=None, user_id=None: (
|
||||
setattr(self, "app_id", app_id) or setattr(self, "user_id", user_id)
|
||||
)
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
@ -1202,8 +1204,9 @@ class TestAdvancedChatAppGeneratorInternals:
|
||||
"_DummyTraceQueueManager",
|
||||
(TraceQueueManager,),
|
||||
{
|
||||
"__init__": lambda self, app_id=None, user_id=None: setattr(self, "app_id", app_id)
|
||||
or setattr(self, "user_id", user_id)
|
||||
"__init__": lambda self, app_id=None, user_id=None: (
|
||||
setattr(self, "app_id", app_id) or setattr(self, "user_id", user_id)
|
||||
)
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
|
||||
@ -240,12 +240,12 @@ class TestAdvancedChatGenerateTaskPipeline:
|
||||
def test_iteration_and_loop_handlers(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._workflow_run_id = "run-id"
|
||||
pipeline._workflow_response_converter.workflow_iteration_start_to_stream_response = (
|
||||
lambda **kwargs: "iter_start"
|
||||
pipeline._workflow_response_converter.workflow_iteration_start_to_stream_response = lambda **kwargs: (
|
||||
"iter_start"
|
||||
)
|
||||
pipeline._workflow_response_converter.workflow_iteration_next_to_stream_response = lambda **kwargs: "iter_next"
|
||||
pipeline._workflow_response_converter.workflow_iteration_completed_to_stream_response = (
|
||||
lambda **kwargs: "iter_done"
|
||||
pipeline._workflow_response_converter.workflow_iteration_completed_to_stream_response = lambda **kwargs: (
|
||||
"iter_done"
|
||||
)
|
||||
pipeline._workflow_response_converter.workflow_loop_start_to_stream_response = lambda **kwargs: "loop_start"
|
||||
pipeline._workflow_response_converter.workflow_loop_next_to_stream_response = lambda **kwargs: "loop_next"
|
||||
|
||||
@ -144,8 +144,9 @@ class TestWorkflowAppGeneratorGenerate:
|
||||
"_DummyTraceQueueManager",
|
||||
(TraceQueueManager,),
|
||||
{
|
||||
"__init__": lambda self, app_id=None, user_id=None: setattr(self, "app_id", app_id)
|
||||
or setattr(self, "user_id", user_id)
|
||||
"__init__": lambda self, app_id=None, user_id=None: (
|
||||
setattr(self, "app_id", app_id) or setattr(self, "user_id", user_id)
|
||||
)
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
|
||||
@ -1,82 +1,603 @@
|
||||
from __future__ import annotations
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch, sentinel
|
||||
|
||||
from typing import Any
|
||||
import pytest
|
||||
|
||||
from core.model_manager import ModelInstance
|
||||
from core.workflow.node_factory import DifyNodeFactory
|
||||
from dify_graph.nodes.llm.entities import LLMNodeData
|
||||
from dify_graph.nodes.llm.node import LLMNode
|
||||
from dify_graph.runtime import GraphRuntimeState, VariablePool
|
||||
from dify_graph.system_variable import SystemVariable
|
||||
from tests.workflow_test_utils import build_test_graph_init_params
|
||||
from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, UserFrom
|
||||
from core.workflow import node_factory
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
|
||||
from dify_graph.enums import NodeType, SystemVariableKey
|
||||
from dify_graph.nodes.code.entities import CodeLanguage
|
||||
from dify_graph.variables.segments import StringSegment
|
||||
|
||||
|
||||
def _build_factory(graph_config: dict[str, Any]) -> DifyNodeFactory:
|
||||
graph_init_params = build_test_graph_init_params(
|
||||
workflow_id="workflow",
|
||||
graph_config=graph_config,
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
user_id="user",
|
||||
user_from="account",
|
||||
invoke_from="debugger",
|
||||
call_depth=0,
|
||||
def _assert_typed_node_config(config, *, node_id: str, node_type: NodeType, version: str = "1") -> None:
|
||||
assert config["id"] == node_id
|
||||
assert isinstance(config["data"], BaseNodeData)
|
||||
assert config["data"].type == node_type
|
||||
assert config["data"].version == version
|
||||
|
||||
|
||||
class TestFetchMemory:
|
||||
@pytest.mark.parametrize(
|
||||
("conversation_id", "memory_config"),
|
||||
[
|
||||
(None, object()),
|
||||
("conversation-id", None),
|
||||
],
|
||||
)
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(
|
||||
system_variables=SystemVariable.default(),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
),
|
||||
start_at=0.0,
|
||||
def test_returns_none_when_memory_or_conversation_is_missing(self, conversation_id, memory_config):
|
||||
result = node_factory.fetch_memory(
|
||||
conversation_id=conversation_id,
|
||||
app_id="app-id",
|
||||
node_data_memory=memory_config,
|
||||
model_instance=sentinel.model_instance,
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_returns_none_when_conversation_does_not_exist(self, monkeypatch):
|
||||
class FakeSelect:
|
||||
def where(self, *_args):
|
||||
return self
|
||||
|
||||
class FakeSession:
|
||||
def __init__(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *_args):
|
||||
return False
|
||||
|
||||
def scalar(self, _stmt):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(node_factory, "db", SimpleNamespace(engine=sentinel.engine))
|
||||
monkeypatch.setattr(node_factory, "select", MagicMock(return_value=FakeSelect()))
|
||||
monkeypatch.setattr(node_factory, "Session", FakeSession)
|
||||
|
||||
result = node_factory.fetch_memory(
|
||||
conversation_id="conversation-id",
|
||||
app_id="app-id",
|
||||
node_data_memory=object(),
|
||||
model_instance=sentinel.model_instance,
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_builds_token_buffer_memory_for_existing_conversation(self, monkeypatch):
|
||||
conversation = sentinel.conversation
|
||||
memory = sentinel.memory
|
||||
|
||||
class FakeSelect:
|
||||
def where(self, *_args):
|
||||
return self
|
||||
|
||||
class FakeSession:
|
||||
def __init__(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *_args):
|
||||
return False
|
||||
|
||||
def scalar(self, _stmt):
|
||||
return conversation
|
||||
|
||||
token_buffer_memory = MagicMock(return_value=memory)
|
||||
monkeypatch.setattr(node_factory, "db", SimpleNamespace(engine=sentinel.engine))
|
||||
monkeypatch.setattr(node_factory, "select", MagicMock(return_value=FakeSelect()))
|
||||
monkeypatch.setattr(node_factory, "Session", FakeSession)
|
||||
monkeypatch.setattr(node_factory, "TokenBufferMemory", token_buffer_memory)
|
||||
|
||||
result = node_factory.fetch_memory(
|
||||
conversation_id="conversation-id",
|
||||
app_id="app-id",
|
||||
node_data_memory=object(),
|
||||
model_instance=sentinel.model_instance,
|
||||
)
|
||||
|
||||
assert result is memory
|
||||
token_buffer_memory.assert_called_once_with(
|
||||
conversation=conversation,
|
||||
model_instance=sentinel.model_instance,
|
||||
)
|
||||
|
||||
|
||||
class TestDefaultWorkflowCodeExecutor:
|
||||
def test_execute_delegates_to_code_executor(self, monkeypatch):
|
||||
executor = node_factory.DefaultWorkflowCodeExecutor()
|
||||
execute_workflow_code_template = MagicMock(return_value={"answer": "ok"})
|
||||
monkeypatch.setattr(
|
||||
node_factory.CodeExecutor,
|
||||
"execute_workflow_code_template",
|
||||
execute_workflow_code_template,
|
||||
)
|
||||
|
||||
result = executor.execute(
|
||||
language=CodeLanguage.PYTHON3,
|
||||
code="print('ok')",
|
||||
inputs={"name": "workflow"},
|
||||
)
|
||||
|
||||
assert result == {"answer": "ok"}
|
||||
execute_workflow_code_template.assert_called_once_with(
|
||||
language=CodeLanguage.PYTHON3,
|
||||
code="print('ok')",
|
||||
inputs={"name": "workflow"},
|
||||
)
|
||||
|
||||
def test_is_execution_error_checks_code_execution_error_type(self):
|
||||
executor = node_factory.DefaultWorkflowCodeExecutor()
|
||||
|
||||
assert executor.is_execution_error(node_factory.CodeExecutionError("boom")) is True
|
||||
assert executor.is_execution_error(RuntimeError("boom")) is False
|
||||
|
||||
|
||||
class TestDifyNodeFactoryInit:
|
||||
def test_init_builds_default_dependencies(self):
|
||||
graph_init_params = SimpleNamespace(run_context={"context": "value"})
|
||||
graph_runtime_state = sentinel.graph_runtime_state
|
||||
dify_context = SimpleNamespace(tenant_id="tenant-id")
|
||||
template_renderer = sentinel.template_renderer
|
||||
rag_retrieval = sentinel.rag_retrieval
|
||||
unstructured_api_config = sentinel.unstructured_api_config
|
||||
http_request_config = sentinel.http_request_config
|
||||
credentials_provider = sentinel.credentials_provider
|
||||
model_factory = sentinel.model_factory
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
node_factory.DifyNodeFactory,
|
||||
"_resolve_dify_context",
|
||||
return_value=dify_context,
|
||||
) as resolve_dify_context,
|
||||
patch.object(
|
||||
node_factory,
|
||||
"CodeExecutorJinja2TemplateRenderer",
|
||||
return_value=template_renderer,
|
||||
) as renderer_factory,
|
||||
patch.object(node_factory, "DatasetRetrieval", return_value=rag_retrieval),
|
||||
patch.object(
|
||||
node_factory,
|
||||
"UnstructuredApiConfig",
|
||||
return_value=unstructured_api_config,
|
||||
),
|
||||
patch.object(
|
||||
node_factory,
|
||||
"build_http_request_config",
|
||||
return_value=http_request_config,
|
||||
),
|
||||
patch.object(
|
||||
node_factory,
|
||||
"build_dify_model_access",
|
||||
return_value=(credentials_provider, model_factory),
|
||||
) as build_dify_model_access,
|
||||
):
|
||||
factory = node_factory.DifyNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
resolve_dify_context.assert_called_once_with(graph_init_params.run_context)
|
||||
build_dify_model_access.assert_called_once_with("tenant-id")
|
||||
renderer_factory.assert_called_once()
|
||||
assert renderer_factory.call_args.kwargs["code_executor"] is factory._code_executor
|
||||
assert factory.graph_init_params is graph_init_params
|
||||
assert factory.graph_runtime_state is graph_runtime_state
|
||||
assert factory._dify_context is dify_context
|
||||
assert factory._template_renderer is template_renderer
|
||||
assert factory._rag_retrieval is rag_retrieval
|
||||
assert factory._document_extractor_unstructured_api_config is unstructured_api_config
|
||||
assert factory._http_request_config is http_request_config
|
||||
assert factory._llm_credentials_provider is credentials_provider
|
||||
assert factory._llm_model_factory is model_factory
|
||||
|
||||
|
||||
class TestDifyNodeFactoryResolveContext:
|
||||
def test_requires_reserved_context_key(self):
|
||||
with pytest.raises(ValueError, match=DIFY_RUN_CONTEXT_KEY):
|
||||
node_factory.DifyNodeFactory._resolve_dify_context({})
|
||||
|
||||
def test_returns_existing_dify_context(self):
|
||||
dify_context = DifyRunContext(
|
||||
tenant_id="tenant-id",
|
||||
app_id="app-id",
|
||||
user_id="user-id",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
)
|
||||
|
||||
result = node_factory.DifyNodeFactory._resolve_dify_context({DIFY_RUN_CONTEXT_KEY: dify_context})
|
||||
|
||||
assert result is dify_context
|
||||
|
||||
def test_validates_mapping_context(self):
|
||||
raw_context = {
|
||||
DIFY_RUN_CONTEXT_KEY: {
|
||||
"tenant_id": "tenant-id",
|
||||
"app_id": "app-id",
|
||||
"user_id": "user-id",
|
||||
"user_from": UserFrom.ACCOUNT,
|
||||
"invoke_from": InvokeFrom.DEBUGGER,
|
||||
}
|
||||
}
|
||||
|
||||
result = node_factory.DifyNodeFactory._resolve_dify_context(raw_context)
|
||||
|
||||
assert isinstance(result, DifyRunContext)
|
||||
assert result.tenant_id == "tenant-id"
|
||||
|
||||
|
||||
class TestDifyNodeFactoryCreateNode:
|
||||
@pytest.fixture
|
||||
def factory(self):
|
||||
factory = object.__new__(node_factory.DifyNodeFactory)
|
||||
factory.graph_init_params = sentinel.graph_init_params
|
||||
factory.graph_runtime_state = sentinel.graph_runtime_state
|
||||
factory._dify_context = SimpleNamespace(tenant_id="tenant-id", app_id="app-id")
|
||||
factory._code_executor = sentinel.code_executor
|
||||
factory._code_limits = sentinel.code_limits
|
||||
factory._template_renderer = sentinel.template_renderer
|
||||
factory._template_transform_max_output_length = 2048
|
||||
factory._http_request_http_client = sentinel.http_client
|
||||
factory._http_request_tool_file_manager_factory = sentinel.tool_file_manager_factory
|
||||
factory._http_request_file_manager = sentinel.file_manager
|
||||
factory._rag_retrieval = sentinel.rag_retrieval
|
||||
factory._document_extractor_unstructured_api_config = sentinel.unstructured_api_config
|
||||
factory._http_request_config = sentinel.http_request_config
|
||||
factory._llm_credentials_provider = sentinel.credentials_provider
|
||||
factory._llm_model_factory = sentinel.model_factory
|
||||
return factory
|
||||
|
||||
def test_rejects_unknown_node_type(self, factory):
|
||||
with pytest.raises(ValueError, match="Input should be"):
|
||||
factory.create_node({"id": "node-id", "data": {"type": "missing"}})
|
||||
|
||||
def test_rejects_missing_class_mapping(self, monkeypatch, factory):
|
||||
monkeypatch.setattr(node_factory, "NODE_TYPE_CLASSES_MAPPING", {})
|
||||
|
||||
with pytest.raises(ValueError, match="No class mapping found for node type: start"):
|
||||
factory.create_node({"id": "node-id", "data": {"type": NodeType.START.value}})
|
||||
|
||||
def test_rejects_missing_latest_class(self, monkeypatch, factory):
|
||||
monkeypatch.setattr(
|
||||
node_factory,
|
||||
"NODE_TYPE_CLASSES_MAPPING",
|
||||
{NodeType.START: {node_factory.LATEST_VERSION: None}},
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="No latest version class found for node type: start"):
|
||||
factory.create_node({"id": "node-id", "data": {"type": NodeType.START.value}})
|
||||
|
||||
def test_uses_version_specific_class_when_available(self, monkeypatch, factory):
|
||||
matched_node = sentinel.matched_node
|
||||
latest_node_class = MagicMock(return_value=sentinel.latest_node)
|
||||
matched_node_class = MagicMock(return_value=matched_node)
|
||||
monkeypatch.setattr(
|
||||
node_factory,
|
||||
"NODE_TYPE_CLASSES_MAPPING",
|
||||
{
|
||||
NodeType.START: {
|
||||
node_factory.LATEST_VERSION: latest_node_class,
|
||||
"9": matched_node_class,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
result = factory.create_node({"id": "node-id", "data": {"type": NodeType.START.value, "version": "9"}})
|
||||
|
||||
assert result is matched_node
|
||||
matched_node_class.assert_called_once()
|
||||
kwargs = matched_node_class.call_args.kwargs
|
||||
assert kwargs["id"] == "node-id"
|
||||
_assert_typed_node_config(kwargs["config"], node_id="node-id", node_type=NodeType.START, version="9")
|
||||
assert kwargs["graph_init_params"] is sentinel.graph_init_params
|
||||
assert kwargs["graph_runtime_state"] is sentinel.graph_runtime_state
|
||||
latest_node_class.assert_not_called()
|
||||
|
||||
def test_falls_back_to_latest_class_when_version_specific_mapping_is_missing(self, monkeypatch, factory):
|
||||
latest_node = sentinel.latest_node
|
||||
latest_node_class = MagicMock(return_value=latest_node)
|
||||
monkeypatch.setattr(
|
||||
node_factory,
|
||||
"NODE_TYPE_CLASSES_MAPPING",
|
||||
{NodeType.START: {node_factory.LATEST_VERSION: latest_node_class}},
|
||||
)
|
||||
|
||||
result = factory.create_node({"id": "node-id", "data": {"type": NodeType.START.value, "version": "9"}})
|
||||
|
||||
assert result is latest_node
|
||||
latest_node_class.assert_called_once()
|
||||
kwargs = latest_node_class.call_args.kwargs
|
||||
assert kwargs["id"] == "node-id"
|
||||
_assert_typed_node_config(kwargs["config"], node_id="node-id", node_type=NodeType.START, version="9")
|
||||
assert kwargs["graph_init_params"] is sentinel.graph_init_params
|
||||
assert kwargs["graph_runtime_state"] is sentinel.graph_runtime_state
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("node_type", "constructor_name"),
|
||||
[
|
||||
(NodeType.CODE, "CodeNode"),
|
||||
(NodeType.TEMPLATE_TRANSFORM, "TemplateTransformNode"),
|
||||
(NodeType.HTTP_REQUEST, "HttpRequestNode"),
|
||||
(NodeType.HUMAN_INPUT, "HumanInputNode"),
|
||||
(NodeType.KNOWLEDGE_INDEX, "KnowledgeIndexNode"),
|
||||
(NodeType.DATASOURCE, "DatasourceNode"),
|
||||
(NodeType.KNOWLEDGE_RETRIEVAL, "KnowledgeRetrievalNode"),
|
||||
(NodeType.DOCUMENT_EXTRACTOR, "DocumentExtractorNode"),
|
||||
],
|
||||
)
|
||||
return DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state)
|
||||
def test_creates_specialized_nodes(self, monkeypatch, factory, node_type, constructor_name):
|
||||
created_node = object()
|
||||
constructor = MagicMock(name=constructor_name, return_value=created_node)
|
||||
monkeypatch.setattr(
|
||||
node_factory,
|
||||
"NODE_TYPE_CLASSES_MAPPING",
|
||||
{node_type: {node_factory.LATEST_VERSION: constructor}},
|
||||
)
|
||||
|
||||
if constructor_name == "HumanInputNode":
|
||||
form_repository = sentinel.form_repository
|
||||
form_repository_impl = MagicMock(return_value=form_repository)
|
||||
monkeypatch.setattr(
|
||||
node_factory,
|
||||
"HumanInputFormRepositoryImpl",
|
||||
form_repository_impl,
|
||||
)
|
||||
elif constructor_name == "KnowledgeIndexNode":
|
||||
index_processor = sentinel.index_processor
|
||||
summary_index = sentinel.summary_index
|
||||
monkeypatch.setattr(node_factory, "IndexProcessor", MagicMock(return_value=index_processor))
|
||||
monkeypatch.setattr(node_factory, "SummaryIndex", MagicMock(return_value=summary_index))
|
||||
|
||||
node_config = {"id": "node-id", "data": {"type": node_type.value}}
|
||||
result = factory.create_node(node_config)
|
||||
|
||||
assert result is created_node
|
||||
kwargs = constructor.call_args.kwargs
|
||||
assert kwargs["id"] == "node-id"
|
||||
_assert_typed_node_config(kwargs["config"], node_id="node-id", node_type=node_type)
|
||||
assert kwargs["graph_init_params"] is sentinel.graph_init_params
|
||||
assert kwargs["graph_runtime_state"] is sentinel.graph_runtime_state
|
||||
|
||||
if constructor_name == "CodeNode":
|
||||
assert kwargs["code_executor"] is sentinel.code_executor
|
||||
assert kwargs["code_limits"] is sentinel.code_limits
|
||||
elif constructor_name == "TemplateTransformNode":
|
||||
assert kwargs["template_renderer"] is sentinel.template_renderer
|
||||
assert kwargs["max_output_length"] == 2048
|
||||
elif constructor_name == "HttpRequestNode":
|
||||
assert kwargs["http_request_config"] is sentinel.http_request_config
|
||||
assert kwargs["http_client"] is sentinel.http_client
|
||||
assert kwargs["tool_file_manager_factory"] is sentinel.tool_file_manager_factory
|
||||
assert kwargs["file_manager"] is sentinel.file_manager
|
||||
elif constructor_name == "HumanInputNode":
|
||||
assert kwargs["form_repository"] is form_repository
|
||||
form_repository_impl.assert_called_once_with(tenant_id="tenant-id")
|
||||
elif constructor_name == "KnowledgeIndexNode":
|
||||
assert kwargs["index_processor"] is index_processor
|
||||
assert kwargs["summary_index_service"] is summary_index
|
||||
elif constructor_name == "DatasourceNode":
|
||||
assert kwargs["datasource_manager"] is node_factory.DatasourceManager
|
||||
elif constructor_name == "KnowledgeRetrievalNode":
|
||||
assert kwargs["rag_retrieval"] is sentinel.rag_retrieval
|
||||
elif constructor_name == "DocumentExtractorNode":
|
||||
assert kwargs["unstructured_api_config"] is sentinel.unstructured_api_config
|
||||
assert kwargs["http_client"] is sentinel.http_client
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("node_type", "constructor_name", "expected_extra_kwargs"),
|
||||
[
|
||||
(NodeType.LLM, "LLMNode", {"http_client": sentinel.http_client}),
|
||||
(NodeType.QUESTION_CLASSIFIER, "QuestionClassifierNode", {"http_client": sentinel.http_client}),
|
||||
(NodeType.PARAMETER_EXTRACTOR, "ParameterExtractorNode", {}),
|
||||
],
|
||||
)
|
||||
def test_creates_model_backed_nodes(
|
||||
self,
|
||||
monkeypatch,
|
||||
factory,
|
||||
node_type,
|
||||
constructor_name,
|
||||
expected_extra_kwargs,
|
||||
):
|
||||
created_node = object()
|
||||
constructor = MagicMock(name=constructor_name, return_value=created_node)
|
||||
monkeypatch.setattr(
|
||||
node_factory,
|
||||
"NODE_TYPE_CLASSES_MAPPING",
|
||||
{node_type: {node_factory.LATEST_VERSION: constructor}},
|
||||
)
|
||||
llm_init_kwargs = {
|
||||
"credentials_provider": sentinel.credentials_provider,
|
||||
"model_factory": sentinel.model_factory,
|
||||
"model_instance": sentinel.model_instance,
|
||||
"memory": sentinel.memory,
|
||||
**expected_extra_kwargs,
|
||||
}
|
||||
build_llm_init_kwargs = MagicMock(return_value=llm_init_kwargs)
|
||||
factory._build_llm_compatible_node_init_kwargs = build_llm_init_kwargs
|
||||
|
||||
node_config = {"id": "node-id", "data": {"type": node_type.value}}
|
||||
result = factory.create_node(node_config)
|
||||
|
||||
assert result is created_node
|
||||
build_llm_init_kwargs.assert_called_once()
|
||||
helper_kwargs = build_llm_init_kwargs.call_args.kwargs
|
||||
assert helper_kwargs["node_class"] is constructor
|
||||
assert isinstance(helper_kwargs["node_data"], BaseNodeData)
|
||||
assert helper_kwargs["node_data"].type == node_type
|
||||
assert helper_kwargs["include_http_client"] is (node_type != NodeType.PARAMETER_EXTRACTOR)
|
||||
|
||||
constructor_kwargs = constructor.call_args.kwargs
|
||||
assert constructor_kwargs["id"] == "node-id"
|
||||
_assert_typed_node_config(constructor_kwargs["config"], node_id="node-id", node_type=node_type)
|
||||
assert constructor_kwargs["graph_init_params"] is sentinel.graph_init_params
|
||||
assert constructor_kwargs["graph_runtime_state"] is sentinel.graph_runtime_state
|
||||
assert constructor_kwargs["credentials_provider"] is sentinel.credentials_provider
|
||||
assert constructor_kwargs["model_factory"] is sentinel.model_factory
|
||||
assert constructor_kwargs["model_instance"] is sentinel.model_instance
|
||||
assert constructor_kwargs["memory"] is sentinel.memory
|
||||
for key, value in expected_extra_kwargs.items():
|
||||
assert constructor_kwargs[key] is value
|
||||
|
||||
|
||||
def test_create_node_uses_declared_node_data_type_for_llm_validation(monkeypatch):
|
||||
class _FactoryLLMNodeData(LLMNodeData):
|
||||
pass
|
||||
class TestDifyNodeFactoryModelInstance:
|
||||
@pytest.fixture
|
||||
def factory(self):
|
||||
factory = object.__new__(node_factory.DifyNodeFactory)
|
||||
factory._llm_credentials_provider = MagicMock()
|
||||
factory._llm_model_factory = MagicMock()
|
||||
return factory
|
||||
|
||||
llm_node_config = {
|
||||
"id": "llm-node",
|
||||
"data": {
|
||||
"type": "llm",
|
||||
"title": "LLM",
|
||||
"model": {
|
||||
"provider": "openai",
|
||||
"name": "gpt-4o-mini",
|
||||
"mode": "chat",
|
||||
"completion_params": {},
|
||||
},
|
||||
"prompt_template": [],
|
||||
"context": {
|
||||
"enabled": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
graph_config = {"nodes": [llm_node_config], "edges": []}
|
||||
factory = _build_factory(graph_config)
|
||||
captured: dict[str, object] = {}
|
||||
@pytest.fixture
|
||||
def llm_model_setup(self, factory):
|
||||
def _configure(
|
||||
*,
|
||||
completion_params=None,
|
||||
has_provider_model=True,
|
||||
model_schema=sentinel.model_schema,
|
||||
):
|
||||
credentials = {"api_key": "secret"}
|
||||
node_data_model = SimpleNamespace(
|
||||
provider="provider",
|
||||
name="model",
|
||||
mode="chat",
|
||||
completion_params=completion_params or {},
|
||||
)
|
||||
node_data = SimpleNamespace(model=node_data_model)
|
||||
provider_model = MagicMock() if has_provider_model else None
|
||||
provider_model_bundle = SimpleNamespace(
|
||||
configuration=SimpleNamespace(get_provider_model=MagicMock(return_value=provider_model))
|
||||
)
|
||||
model_type_instance = MagicMock()
|
||||
model_type_instance.get_model_schema.return_value = model_schema
|
||||
model_instance = SimpleNamespace(
|
||||
provider_model_bundle=provider_model_bundle,
|
||||
model_type_instance=model_type_instance,
|
||||
provider=None,
|
||||
model_name=None,
|
||||
credentials=None,
|
||||
parameters=None,
|
||||
stop=None,
|
||||
)
|
||||
factory._llm_credentials_provider.fetch.return_value = credentials
|
||||
factory._llm_model_factory.init_model_instance.return_value = model_instance
|
||||
return SimpleNamespace(
|
||||
node_data=node_data,
|
||||
credentials=credentials,
|
||||
provider_model=provider_model,
|
||||
model_type_instance=model_type_instance,
|
||||
model_instance=model_instance,
|
||||
)
|
||||
|
||||
monkeypatch.setattr(LLMNode, "_node_data_type", _FactoryLLMNodeData)
|
||||
return _configure
|
||||
|
||||
def _capture_model_instance(self: DifyNodeFactory, node_data: object) -> ModelInstance:
|
||||
captured["node_data"] = node_data
|
||||
return object() # type: ignore[return-value]
|
||||
def test_requires_llm_mode(self, factory):
|
||||
node_data = SimpleNamespace(
|
||||
model=SimpleNamespace(
|
||||
provider="provider",
|
||||
name="model",
|
||||
mode="",
|
||||
completion_params={},
|
||||
)
|
||||
)
|
||||
|
||||
def _capture_memory(
|
||||
self: DifyNodeFactory,
|
||||
*,
|
||||
node_data: object,
|
||||
model_instance: ModelInstance,
|
||||
) -> None:
|
||||
captured["memory_node_data"] = node_data
|
||||
with pytest.raises(node_factory.LLMModeRequiredError, match="LLM mode is required"):
|
||||
factory._build_model_instance_for_llm_node(node_data)
|
||||
|
||||
monkeypatch.setattr(DifyNodeFactory, "_build_model_instance_for_llm_node", _capture_model_instance)
|
||||
monkeypatch.setattr(DifyNodeFactory, "_build_memory_for_llm_node", _capture_memory)
|
||||
def test_raises_when_provider_model_is_missing(self, factory, llm_model_setup):
|
||||
setup = llm_model_setup(has_provider_model=False)
|
||||
|
||||
node = factory.create_node(llm_node_config)
|
||||
with pytest.raises(node_factory.ModelNotExistError, match="Model model not exist"):
|
||||
factory._build_model_instance_for_llm_node(setup.node_data)
|
||||
|
||||
assert isinstance(captured["node_data"], _FactoryLLMNodeData)
|
||||
assert isinstance(captured["memory_node_data"], _FactoryLLMNodeData)
|
||||
assert isinstance(node.node_data, _FactoryLLMNodeData)
|
||||
def test_raises_when_model_schema_is_missing(self, factory, llm_model_setup):
|
||||
setup = llm_model_setup(model_schema=None)
|
||||
|
||||
with pytest.raises(node_factory.ModelNotExistError, match="Model model not exist"):
|
||||
factory._build_model_instance_for_llm_node(setup.node_data)
|
||||
|
||||
setup.provider_model.raise_for_status.assert_called_once()
|
||||
|
||||
def test_builds_model_instance_and_normalizes_stop_tokens(self, factory, llm_model_setup):
|
||||
setup = llm_model_setup(
|
||||
completion_params={"temperature": 0.3, "stop": "not-a-list"},
|
||||
model_schema={"schema": "value"},
|
||||
)
|
||||
|
||||
result = factory._build_model_instance_for_llm_node(setup.node_data)
|
||||
|
||||
assert result is setup.model_instance
|
||||
assert result.provider == "provider"
|
||||
assert result.model_name == "model"
|
||||
assert result.credentials == setup.credentials
|
||||
assert result.parameters == {"temperature": 0.3}
|
||||
assert result.stop == ()
|
||||
assert result.model_type_instance is setup.model_type_instance
|
||||
setup.provider_model.raise_for_status.assert_called_once()
|
||||
|
||||
|
||||
class TestDifyNodeFactoryMemory:
|
||||
@pytest.fixture
|
||||
def factory(self):
|
||||
factory = object.__new__(node_factory.DifyNodeFactory)
|
||||
factory._dify_context = SimpleNamespace(app_id="app-id")
|
||||
factory.graph_runtime_state = SimpleNamespace(variable_pool=MagicMock())
|
||||
return factory
|
||||
|
||||
def test_returns_none_when_memory_is_not_configured(self, factory):
|
||||
result = factory._build_memory_for_llm_node(
|
||||
node_data=SimpleNamespace(memory=None),
|
||||
model_instance=sentinel.model_instance,
|
||||
)
|
||||
|
||||
assert result is None
|
||||
factory.graph_runtime_state.variable_pool.get.assert_not_called()
|
||||
|
||||
def test_uses_string_segment_conversation_id(self, monkeypatch, factory):
|
||||
memory_config = sentinel.memory_config
|
||||
factory.graph_runtime_state.variable_pool.get.return_value = StringSegment(value="conversation-id")
|
||||
fetch_memory = MagicMock(return_value=sentinel.memory)
|
||||
monkeypatch.setattr(node_factory, "fetch_memory", fetch_memory)
|
||||
|
||||
result = factory._build_memory_for_llm_node(
|
||||
node_data=SimpleNamespace(memory=memory_config),
|
||||
model_instance=sentinel.model_instance,
|
||||
)
|
||||
|
||||
assert result is sentinel.memory
|
||||
factory.graph_runtime_state.variable_pool.get.assert_called_once_with(
|
||||
["sys", SystemVariableKey.CONVERSATION_ID]
|
||||
)
|
||||
fetch_memory.assert_called_once_with(
|
||||
conversation_id="conversation-id",
|
||||
app_id="app-id",
|
||||
node_data_memory=memory_config,
|
||||
model_instance=sentinel.model_instance,
|
||||
)
|
||||
|
||||
def test_ignores_non_string_segment_conversation_ids(self, monkeypatch, factory):
|
||||
memory_config = sentinel.memory_config
|
||||
factory.graph_runtime_state.variable_pool.get.return_value = sentinel.segment
|
||||
fetch_memory = MagicMock(return_value=sentinel.memory)
|
||||
monkeypatch.setattr(node_factory, "fetch_memory", fetch_memory)
|
||||
|
||||
result = factory._build_memory_for_llm_node(
|
||||
node_data=SimpleNamespace(memory=memory_config),
|
||||
model_instance=sentinel.model_instance,
|
||||
)
|
||||
|
||||
assert result is sentinel.memory
|
||||
fetch_memory.assert_called_once_with(
|
||||
conversation_id=None,
|
||||
app_id="app-id",
|
||||
node_data_memory=memory_config,
|
||||
model_instance=sentinel.model_instance,
|
||||
)
|
||||
|
||||
@ -0,0 +1,656 @@
|
||||
from collections import UserString
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch, sentinel
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
|
||||
from core.workflow import workflow_entry
|
||||
from dify_graph.entities.graph_config import NodeConfigDictAdapter
|
||||
from dify_graph.errors import WorkflowNodeRunFailedError
|
||||
from dify_graph.file.enums import FileTransferMethod, FileType
|
||||
from dify_graph.file.models import File
|
||||
from dify_graph.graph_events import GraphRunFailedEvent
|
||||
from dify_graph.nodes import NodeType
|
||||
from dify_graph.runtime import ChildGraphNotFoundError
|
||||
|
||||
|
||||
def _build_typed_node_config(node_type: NodeType):
|
||||
return NodeConfigDictAdapter.validate_python({"id": "node-id", "data": {"type": node_type}})
|
||||
|
||||
|
||||
class TestWorkflowChildEngineBuilder:
|
||||
@pytest.mark.parametrize(
|
||||
("graph_config", "node_id", "expected"),
|
||||
[
|
||||
({"nodes": [{"id": "root"}]}, "root", True),
|
||||
({"nodes": [{"id": "root"}]}, "other", False),
|
||||
({"nodes": "invalid"}, "root", None),
|
||||
({"nodes": ["invalid"]}, "root", None),
|
||||
],
|
||||
)
|
||||
def test_has_node_id(self, graph_config, node_id, expected):
|
||||
result = workflow_entry._WorkflowChildEngineBuilder._has_node_id(graph_config, node_id)
|
||||
|
||||
assert result is expected
|
||||
|
||||
def test_build_child_engine_raises_when_root_node_is_missing(self):
|
||||
builder = workflow_entry._WorkflowChildEngineBuilder()
|
||||
|
||||
with patch.object(workflow_entry, "DifyNodeFactory", return_value=sentinel.factory):
|
||||
with pytest.raises(ChildGraphNotFoundError, match="child graph root node 'missing' not found"):
|
||||
builder.build_child_engine(
|
||||
workflow_id="workflow-id",
|
||||
graph_init_params=sentinel.graph_init_params,
|
||||
graph_runtime_state=sentinel.graph_runtime_state,
|
||||
graph_config={"nodes": []},
|
||||
root_node_id="missing",
|
||||
)
|
||||
|
||||
def test_build_child_engine_constructs_graph_engine_and_layers(self):
|
||||
builder = workflow_entry._WorkflowChildEngineBuilder()
|
||||
child_graph = sentinel.child_graph
|
||||
child_engine = MagicMock()
|
||||
quota_layer = sentinel.quota_layer
|
||||
additional_layers = [sentinel.layer_one, sentinel.layer_two]
|
||||
|
||||
with (
|
||||
patch.object(workflow_entry, "DifyNodeFactory", return_value=sentinel.factory) as dify_node_factory,
|
||||
patch.object(workflow_entry.Graph, "init", return_value=child_graph) as graph_init,
|
||||
patch.object(workflow_entry, "GraphEngine", return_value=child_engine) as graph_engine_cls,
|
||||
patch.object(workflow_entry, "GraphEngineConfig", return_value=sentinel.graph_engine_config),
|
||||
patch.object(workflow_entry, "InMemoryChannel", return_value=sentinel.command_channel),
|
||||
patch.object(workflow_entry, "LLMQuotaLayer", return_value=quota_layer),
|
||||
):
|
||||
result = builder.build_child_engine(
|
||||
workflow_id="workflow-id",
|
||||
graph_init_params=sentinel.graph_init_params,
|
||||
graph_runtime_state=sentinel.graph_runtime_state,
|
||||
graph_config={"nodes": [{"id": "root"}]},
|
||||
root_node_id="root",
|
||||
layers=additional_layers,
|
||||
)
|
||||
|
||||
assert result is child_engine
|
||||
dify_node_factory.assert_called_once_with(
|
||||
graph_init_params=sentinel.graph_init_params,
|
||||
graph_runtime_state=sentinel.graph_runtime_state,
|
||||
)
|
||||
graph_init.assert_called_once_with(
|
||||
graph_config={"nodes": [{"id": "root"}]},
|
||||
node_factory=sentinel.factory,
|
||||
root_node_id="root",
|
||||
)
|
||||
graph_engine_cls.assert_called_once_with(
|
||||
workflow_id="workflow-id",
|
||||
graph=child_graph,
|
||||
graph_runtime_state=sentinel.graph_runtime_state,
|
||||
command_channel=sentinel.command_channel,
|
||||
config=sentinel.graph_engine_config,
|
||||
child_engine_builder=builder,
|
||||
)
|
||||
assert child_engine.layer.call_args_list == [
|
||||
((quota_layer,), {}),
|
||||
((sentinel.layer_one,), {}),
|
||||
((sentinel.layer_two,), {}),
|
||||
]
|
||||
|
||||
|
||||
class TestWorkflowEntryInit:
|
||||
def test_rejects_call_depth_above_limit(self):
|
||||
call_depth = workflow_entry.dify_config.WORKFLOW_CALL_MAX_DEPTH + 1
|
||||
|
||||
with pytest.raises(ValueError, match="Max workflow call depth"):
|
||||
workflow_entry.WorkflowEntry(
|
||||
tenant_id="tenant-id",
|
||||
app_id="app-id",
|
||||
workflow_id="workflow-id",
|
||||
graph_config={"nodes": [], "edges": []},
|
||||
graph=sentinel.graph,
|
||||
user_id="user-id",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=call_depth,
|
||||
variable_pool=sentinel.variable_pool,
|
||||
graph_runtime_state=sentinel.graph_runtime_state,
|
||||
)
|
||||
|
||||
def test_applies_debug_and_observability_layers(self):
|
||||
graph_engine = MagicMock()
|
||||
debug_layer = sentinel.debug_layer
|
||||
execution_limits_layer = sentinel.execution_limits_layer
|
||||
llm_quota_layer = sentinel.llm_quota_layer
|
||||
observability_layer = sentinel.observability_layer
|
||||
|
||||
with (
|
||||
patch.object(workflow_entry.dify_config, "DEBUG", True),
|
||||
patch.object(workflow_entry.dify_config, "ENABLE_OTEL", False),
|
||||
patch.object(workflow_entry, "is_instrument_flag_enabled", return_value=True),
|
||||
patch.object(workflow_entry, "GraphEngine", return_value=graph_engine) as graph_engine_cls,
|
||||
patch.object(workflow_entry, "GraphEngineConfig", return_value=sentinel.graph_engine_config),
|
||||
patch.object(workflow_entry, "InMemoryChannel", return_value=sentinel.command_channel),
|
||||
patch.object(workflow_entry, "DebugLoggingLayer", return_value=debug_layer) as debug_logging_layer,
|
||||
patch.object(
|
||||
workflow_entry,
|
||||
"ExecutionLimitsLayer",
|
||||
return_value=execution_limits_layer,
|
||||
) as execution_limits_layer_cls,
|
||||
patch.object(workflow_entry, "LLMQuotaLayer", return_value=llm_quota_layer),
|
||||
patch.object(workflow_entry, "ObservabilityLayer", return_value=observability_layer),
|
||||
):
|
||||
entry = workflow_entry.WorkflowEntry(
|
||||
tenant_id="tenant-id",
|
||||
app_id="app-id",
|
||||
workflow_id="workflow-id-123456",
|
||||
graph_config={"nodes": [], "edges": []},
|
||||
graph=sentinel.graph,
|
||||
user_id="user-id",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
variable_pool=sentinel.variable_pool,
|
||||
graph_runtime_state=sentinel.graph_runtime_state,
|
||||
command_channel=None,
|
||||
)
|
||||
|
||||
assert entry.command_channel is sentinel.command_channel
|
||||
graph_engine_cls.assert_called_once_with(
|
||||
workflow_id="workflow-id-123456",
|
||||
graph=sentinel.graph,
|
||||
graph_runtime_state=sentinel.graph_runtime_state,
|
||||
command_channel=sentinel.command_channel,
|
||||
config=sentinel.graph_engine_config,
|
||||
child_engine_builder=entry._child_engine_builder,
|
||||
)
|
||||
debug_logging_layer.assert_called_once_with(
|
||||
level="DEBUG",
|
||||
include_inputs=True,
|
||||
include_outputs=True,
|
||||
include_process_data=False,
|
||||
logger_name="GraphEngine.Debug.workflow",
|
||||
)
|
||||
execution_limits_layer_cls.assert_called_once_with(
|
||||
max_steps=workflow_entry.dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
|
||||
max_time=workflow_entry.dify_config.WORKFLOW_MAX_EXECUTION_TIME,
|
||||
)
|
||||
assert graph_engine.layer.call_args_list == [
|
||||
((debug_layer,), {}),
|
||||
((execution_limits_layer,), {}),
|
||||
((llm_quota_layer,), {}),
|
||||
((observability_layer,), {}),
|
||||
]
|
||||
|
||||
|
||||
class TestWorkflowEntryRun:
|
||||
def test_run_swallows_generate_task_stopped_errors(self):
|
||||
entry = object.__new__(workflow_entry.WorkflowEntry)
|
||||
entry.graph_engine = MagicMock()
|
||||
entry.graph_engine.run.side_effect = GenerateTaskStoppedError()
|
||||
|
||||
assert list(entry.run()) == []
|
||||
|
||||
def test_run_emits_failed_event_for_unexpected_errors(self):
|
||||
entry = object.__new__(workflow_entry.WorkflowEntry)
|
||||
entry.graph_engine = MagicMock()
|
||||
entry.graph_engine.run.side_effect = RuntimeError("boom")
|
||||
|
||||
events = list(entry.run())
|
||||
|
||||
assert len(events) == 1
|
||||
assert isinstance(events[0], GraphRunFailedEvent)
|
||||
assert events[0].error == "boom"
|
||||
|
||||
|
||||
class TestWorkflowEntrySingleStepRun:
|
||||
def test_uses_empty_mapping_when_selector_extraction_is_not_implemented(self):
|
||||
class FakeNode:
|
||||
id = "node-id"
|
||||
title = "Node Title"
|
||||
node_type = "fake"
|
||||
|
||||
@staticmethod
|
||||
def version():
|
||||
return "1"
|
||||
|
||||
@staticmethod
|
||||
def extract_variable_selector_to_variable_mapping(**_kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
with (
|
||||
patch.object(workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params),
|
||||
patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state),
|
||||
patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}),
|
||||
patch.object(workflow_entry.time, "perf_counter", return_value=123.0),
|
||||
patch.object(workflow_entry, "DifyNodeFactory") as dify_node_factory,
|
||||
patch.object(workflow_entry, "load_into_variable_pool") as load_into_variable_pool,
|
||||
patch.object(
|
||||
workflow_entry.WorkflowEntry,
|
||||
"mapping_user_inputs_to_variable_pool",
|
||||
) as mapping_user_inputs_to_variable_pool,
|
||||
patch.object(
|
||||
workflow_entry.WorkflowEntry,
|
||||
"_traced_node_run",
|
||||
return_value=iter(["event"]),
|
||||
),
|
||||
):
|
||||
dify_node_factory.return_value.create_node.return_value = FakeNode()
|
||||
workflow = SimpleNamespace(
|
||||
tenant_id="tenant-id",
|
||||
app_id="app-id",
|
||||
id="workflow-id",
|
||||
graph_dict={"nodes": [], "edges": []},
|
||||
get_node_config_by_id=lambda _node_id: _build_typed_node_config(NodeType.START),
|
||||
)
|
||||
|
||||
node, generator = workflow_entry.WorkflowEntry.single_step_run(
|
||||
workflow=workflow,
|
||||
node_id="node-id",
|
||||
user_id="user-id",
|
||||
user_inputs={"question": "hello"},
|
||||
variable_pool=sentinel.variable_pool,
|
||||
)
|
||||
|
||||
assert node.id == "node-id"
|
||||
assert list(generator) == ["event"]
|
||||
load_into_variable_pool.assert_called_once_with(
|
||||
variable_loader=workflow_entry.DUMMY_VARIABLE_LOADER,
|
||||
variable_pool=sentinel.variable_pool,
|
||||
variable_mapping={},
|
||||
user_inputs={"question": "hello"},
|
||||
)
|
||||
mapping_user_inputs_to_variable_pool.assert_called_once_with(
|
||||
variable_mapping={},
|
||||
user_inputs={"question": "hello"},
|
||||
variable_pool=sentinel.variable_pool,
|
||||
tenant_id="tenant-id",
|
||||
)
|
||||
|
||||
def test_skips_user_input_mapping_for_datasource_nodes(self):
|
||||
class FakeDatasourceNode:
|
||||
id = "node-id"
|
||||
node_type = "datasource"
|
||||
|
||||
@staticmethod
|
||||
def version():
|
||||
return "1"
|
||||
|
||||
@staticmethod
|
||||
def extract_variable_selector_to_variable_mapping(**_kwargs):
|
||||
return {"question": ["node", "question"]}
|
||||
|
||||
with (
|
||||
patch.object(workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params),
|
||||
patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state),
|
||||
patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}),
|
||||
patch.object(workflow_entry.time, "perf_counter", return_value=123.0),
|
||||
patch.object(workflow_entry, "DifyNodeFactory") as dify_node_factory,
|
||||
patch.object(workflow_entry, "load_into_variable_pool") as load_into_variable_pool,
|
||||
patch.object(
|
||||
workflow_entry.WorkflowEntry,
|
||||
"mapping_user_inputs_to_variable_pool",
|
||||
) as mapping_user_inputs_to_variable_pool,
|
||||
patch.object(
|
||||
workflow_entry.WorkflowEntry,
|
||||
"_traced_node_run",
|
||||
return_value=iter(["event"]),
|
||||
),
|
||||
):
|
||||
dify_node_factory.return_value.create_node.return_value = FakeDatasourceNode()
|
||||
workflow = SimpleNamespace(
|
||||
tenant_id="tenant-id",
|
||||
app_id="app-id",
|
||||
id="workflow-id",
|
||||
graph_dict={"nodes": [], "edges": []},
|
||||
get_node_config_by_id=lambda _node_id: _build_typed_node_config(NodeType.DATASOURCE),
|
||||
)
|
||||
|
||||
node, generator = workflow_entry.WorkflowEntry.single_step_run(
|
||||
workflow=workflow,
|
||||
node_id="node-id",
|
||||
user_id="user-id",
|
||||
user_inputs={"question": "hello"},
|
||||
variable_pool=sentinel.variable_pool,
|
||||
)
|
||||
|
||||
assert node.id == "node-id"
|
||||
assert list(generator) == ["event"]
|
||||
load_into_variable_pool.assert_called_once()
|
||||
mapping_user_inputs_to_variable_pool.assert_not_called()
|
||||
|
||||
def test_wraps_traced_node_run_failures(self):
|
||||
class FakeNode:
|
||||
id = "node-id"
|
||||
title = "Node Title"
|
||||
node_type = "fake"
|
||||
|
||||
@staticmethod
|
||||
def extract_variable_selector_to_variable_mapping(**_kwargs):
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def version():
|
||||
return "1"
|
||||
|
||||
with (
|
||||
patch.object(workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params),
|
||||
patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state),
|
||||
patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}),
|
||||
patch.object(workflow_entry.time, "perf_counter", return_value=123.0),
|
||||
patch.object(workflow_entry, "DifyNodeFactory") as dify_node_factory,
|
||||
patch.object(workflow_entry, "load_into_variable_pool"),
|
||||
patch.object(workflow_entry.WorkflowEntry, "mapping_user_inputs_to_variable_pool"),
|
||||
patch.object(
|
||||
workflow_entry.WorkflowEntry,
|
||||
"_traced_node_run",
|
||||
side_effect=RuntimeError("boom"),
|
||||
),
|
||||
):
|
||||
dify_node_factory.return_value.create_node.return_value = FakeNode()
|
||||
workflow = SimpleNamespace(
|
||||
tenant_id="tenant-id",
|
||||
app_id="app-id",
|
||||
id="workflow-id",
|
||||
graph_dict={"nodes": [], "edges": []},
|
||||
get_node_config_by_id=lambda _node_id: _build_typed_node_config(NodeType.START),
|
||||
)
|
||||
|
||||
with pytest.raises(WorkflowNodeRunFailedError):
|
||||
workflow_entry.WorkflowEntry.single_step_run(
|
||||
workflow=workflow,
|
||||
node_id="node-id",
|
||||
user_id="user-id",
|
||||
user_inputs={},
|
||||
variable_pool=sentinel.variable_pool,
|
||||
)
|
||||
|
||||
|
||||
class TestWorkflowEntryHelpers:
|
||||
def test_create_single_node_graph_builds_start_edge(self):
|
||||
graph = workflow_entry.WorkflowEntry._create_single_node_graph(
|
||||
node_id="target-node",
|
||||
node_data={"type": NodeType.PARAMETER_EXTRACTOR},
|
||||
node_width=320,
|
||||
node_height=180,
|
||||
)
|
||||
|
||||
assert graph["nodes"][0]["id"] == "start"
|
||||
assert graph["nodes"][1]["id"] == "target-node"
|
||||
assert graph["nodes"][1]["width"] == 320
|
||||
assert graph["nodes"][1]["height"] == 180
|
||||
assert graph["edges"] == [
|
||||
{
|
||||
"source": "start",
|
||||
"target": "target-node",
|
||||
"sourceHandle": "source",
|
||||
"targetHandle": "target",
|
||||
}
|
||||
]
|
||||
|
||||
def test_run_free_node_rejects_unsupported_types(self):
|
||||
with pytest.raises(ValueError, match="Node type start not supported"):
|
||||
workflow_entry.WorkflowEntry.run_free_node(
|
||||
node_data={"type": NodeType.START.value},
|
||||
node_id="node-id",
|
||||
tenant_id="tenant-id",
|
||||
user_id="user-id",
|
||||
user_inputs={},
|
||||
)
|
||||
|
||||
def test_run_free_node_rejects_missing_node_class(self, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
workflow_entry,
|
||||
"NODE_TYPE_CLASSES_MAPPING",
|
||||
{NodeType.PARAMETER_EXTRACTOR: {"1": None}},
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Node class not found for node type parameter-extractor"):
|
||||
workflow_entry.WorkflowEntry.run_free_node(
|
||||
node_data={"type": NodeType.PARAMETER_EXTRACTOR.value},
|
||||
node_id="node-id",
|
||||
tenant_id="tenant-id",
|
||||
user_id="user-id",
|
||||
user_inputs={},
|
||||
)
|
||||
|
||||
def test_run_free_node_uses_empty_mapping_when_selector_extraction_is_not_implemented(self, monkeypatch):
|
||||
class FakeNodeClass:
|
||||
@staticmethod
|
||||
def extract_variable_selector_to_variable_mapping(**_kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
class FakeNode:
|
||||
id = "node-id"
|
||||
title = "Node Title"
|
||||
node_type = "parameter-extractor"
|
||||
|
||||
@staticmethod
|
||||
def version():
|
||||
return "1"
|
||||
|
||||
dify_node_factory = MagicMock()
|
||||
dify_node_factory.create_node.return_value = FakeNode()
|
||||
monkeypatch.setattr(
|
||||
workflow_entry,
|
||||
"NODE_TYPE_CLASSES_MAPPING",
|
||||
{NodeType.PARAMETER_EXTRACTOR: {"1": FakeNodeClass}},
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(workflow_entry.SystemVariable, "default", return_value=sentinel.system_variables),
|
||||
patch.object(workflow_entry, "VariablePool", return_value=sentinel.variable_pool) as variable_pool_cls,
|
||||
patch.object(
|
||||
workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params
|
||||
) as graph_init_params,
|
||||
patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state),
|
||||
patch.object(
|
||||
workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}
|
||||
) as build_dify_run_context,
|
||||
patch.object(workflow_entry.time, "perf_counter", return_value=123.0),
|
||||
patch.object(workflow_entry, "DifyNodeFactory", return_value=dify_node_factory) as dify_node_factory_cls,
|
||||
patch.object(
|
||||
workflow_entry.WorkflowEntry,
|
||||
"mapping_user_inputs_to_variable_pool",
|
||||
) as mapping_user_inputs_to_variable_pool,
|
||||
patch.object(
|
||||
workflow_entry.WorkflowEntry,
|
||||
"_traced_node_run",
|
||||
return_value=iter(["event"]),
|
||||
),
|
||||
):
|
||||
node, generator = workflow_entry.WorkflowEntry.run_free_node(
|
||||
node_data={"type": NodeType.PARAMETER_EXTRACTOR.value, "title": "Node"},
|
||||
node_id="node-id",
|
||||
tenant_id="tenant-id",
|
||||
user_id="user-id",
|
||||
user_inputs={"question": "hello"},
|
||||
)
|
||||
|
||||
assert node.id == "node-id"
|
||||
assert list(generator) == ["event"]
|
||||
variable_pool_cls.assert_called_once_with(
|
||||
system_variables=sentinel.system_variables,
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
)
|
||||
build_dify_run_context.assert_called_once_with(
|
||||
tenant_id="tenant-id",
|
||||
app_id="",
|
||||
user_id="user-id",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
)
|
||||
graph_init_params.assert_called_once_with(
|
||||
workflow_id="",
|
||||
graph_config=workflow_entry.WorkflowEntry._create_single_node_graph(
|
||||
"node-id", {"type": NodeType.PARAMETER_EXTRACTOR.value, "title": "Node"}
|
||||
),
|
||||
run_context={"_dify": "context"},
|
||||
call_depth=0,
|
||||
)
|
||||
dify_node_factory_cls.assert_called_once_with(
|
||||
graph_init_params=sentinel.graph_init_params,
|
||||
graph_runtime_state=sentinel.graph_runtime_state,
|
||||
)
|
||||
mapping_user_inputs_to_variable_pool.assert_called_once_with(
|
||||
variable_mapping={},
|
||||
user_inputs={"question": "hello"},
|
||||
variable_pool=sentinel.variable_pool,
|
||||
tenant_id="tenant-id",
|
||||
)
|
||||
|
||||
def test_run_free_node_wraps_execution_failures(self, monkeypatch):
|
||||
class FakeNodeClass:
|
||||
@staticmethod
|
||||
def extract_variable_selector_to_variable_mapping(**_kwargs):
|
||||
return {}
|
||||
|
||||
class FakeNode:
|
||||
id = "node-id"
|
||||
title = "Node Title"
|
||||
node_type = "parameter-extractor"
|
||||
|
||||
@staticmethod
|
||||
def version():
|
||||
return "1"
|
||||
|
||||
dify_node_factory = MagicMock()
|
||||
dify_node_factory.create_node.return_value = FakeNode()
|
||||
monkeypatch.setattr(
|
||||
workflow_entry,
|
||||
"NODE_TYPE_CLASSES_MAPPING",
|
||||
{NodeType.PARAMETER_EXTRACTOR: {"1": FakeNodeClass}},
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(workflow_entry.SystemVariable, "default", return_value=sentinel.system_variables),
|
||||
patch.object(workflow_entry, "VariablePool", return_value=sentinel.variable_pool),
|
||||
patch.object(workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params),
|
||||
patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state),
|
||||
patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}),
|
||||
patch.object(workflow_entry.time, "perf_counter", return_value=123.0),
|
||||
patch.object(workflow_entry, "DifyNodeFactory", return_value=dify_node_factory),
|
||||
patch.object(
|
||||
workflow_entry.WorkflowEntry,
|
||||
"mapping_user_inputs_to_variable_pool",
|
||||
side_effect=RuntimeError("boom"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(WorkflowNodeRunFailedError, match="Node Title run failed: boom"):
|
||||
workflow_entry.WorkflowEntry.run_free_node(
|
||||
node_data={"type": NodeType.PARAMETER_EXTRACTOR.value, "title": "Node"},
|
||||
node_id="node-id",
|
||||
tenant_id="tenant-id",
|
||||
user_id="user-id",
|
||||
user_inputs={"question": "hello"},
|
||||
)
|
||||
|
||||
def test_handle_special_values_serializes_nested_files(self):
|
||||
file = File(
|
||||
tenant_id="tenant-id",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="https://example.com/image.png",
|
||||
filename="image.png",
|
||||
extension=".png",
|
||||
)
|
||||
|
||||
result = workflow_entry.WorkflowEntry.handle_special_values({"file": file, "nested": {"files": [file]}})
|
||||
|
||||
assert result == {
|
||||
"file": file.to_dict(),
|
||||
"nested": {"files": [file.to_dict()]},
|
||||
}
|
||||
|
||||
def test_handle_special_values_returns_none_for_none(self):
|
||||
assert workflow_entry.WorkflowEntry._handle_special_values(None) is None
|
||||
|
||||
def test_handle_special_values_returns_scalar_as_is(self):
|
||||
assert workflow_entry.WorkflowEntry._handle_special_values("plain-text") == "plain-text"
|
||||
|
||||
|
||||
class TestMappingUserInputsBranches:
|
||||
def test_rejects_invalid_node_variable_key(self):
|
||||
class EmptySplitKey(UserString):
|
||||
def split(self, _sep=None):
|
||||
return []
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid node variable broken"):
|
||||
workflow_entry.WorkflowEntry.mapping_user_inputs_to_variable_pool(
|
||||
variable_mapping={EmptySplitKey("broken"): ["node", "input"]},
|
||||
user_inputs={},
|
||||
variable_pool=MagicMock(),
|
||||
tenant_id="tenant-id",
|
||||
)
|
||||
|
||||
def test_skips_none_user_input_when_variable_already_exists(self):
|
||||
variable_pool = MagicMock()
|
||||
variable_pool.get.return_value = None
|
||||
|
||||
workflow_entry.WorkflowEntry.mapping_user_inputs_to_variable_pool(
|
||||
variable_mapping={"node.input": ["target", "input"]},
|
||||
user_inputs={"node.input": None},
|
||||
variable_pool=variable_pool,
|
||||
tenant_id="tenant-id",
|
||||
)
|
||||
|
||||
variable_pool.add.assert_not_called()
|
||||
|
||||
def test_merges_structured_output_values(self):
|
||||
variable_pool = MagicMock()
|
||||
variable_pool.get.side_effect = [
|
||||
None,
|
||||
SimpleNamespace(value={"existing": "value"}),
|
||||
]
|
||||
|
||||
workflow_entry.WorkflowEntry.mapping_user_inputs_to_variable_pool(
|
||||
variable_mapping={"node.answer": ["target", "structured_output", "answer"]},
|
||||
user_inputs={"node.answer": "new-value"},
|
||||
variable_pool=variable_pool,
|
||||
tenant_id="tenant-id",
|
||||
)
|
||||
|
||||
variable_pool.add.assert_called_once_with(
|
||||
["target", "structured_output"],
|
||||
{"existing": "value", "answer": "new-value"},
|
||||
)
|
||||
|
||||
|
||||
class TestWorkflowEntryTracing:
|
||||
def test_traced_node_run_reports_success(self):
|
||||
layer = MagicMock()
|
||||
|
||||
class FakeNode:
|
||||
def ensure_execution_id(self):
|
||||
return None
|
||||
|
||||
def run(self):
|
||||
yield "event"
|
||||
|
||||
with patch.object(workflow_entry, "ObservabilityLayer", return_value=layer):
|
||||
events = list(workflow_entry.WorkflowEntry._traced_node_run(FakeNode()))
|
||||
|
||||
assert events == ["event"]
|
||||
layer.on_graph_start.assert_called_once_with()
|
||||
layer.on_node_run_start.assert_called_once()
|
||||
layer.on_node_run_end.assert_called_once_with(
|
||||
layer.on_node_run_start.call_args.args[0],
|
||||
None,
|
||||
)
|
||||
|
||||
def test_traced_node_run_reports_errors(self):
|
||||
layer = MagicMock()
|
||||
|
||||
class FakeNode:
|
||||
def ensure_execution_id(self):
|
||||
return None
|
||||
|
||||
def run(self):
|
||||
raise RuntimeError("boom")
|
||||
yield
|
||||
|
||||
with patch.object(workflow_entry, "ObservabilityLayer", return_value=layer):
|
||||
with pytest.raises(RuntimeError, match="boom"):
|
||||
list(workflow_entry.WorkflowEntry._traced_node_run(FakeNode()))
|
||||
|
||||
assert isinstance(layer.on_node_run_end.call_args.args[1], RuntimeError)
|
||||
@ -311,7 +311,9 @@ class TestWorkflowService:
|
||||
mock_workflow.conversation_variables = []
|
||||
|
||||
# Mock node config
|
||||
mock_workflow.get_node_config_by_id.return_value = {"id": "node-1", "data": {"type": "llm"}}
|
||||
mock_workflow.get_node_config_by_id.return_value = NodeConfigDictAdapter.validate_python(
|
||||
{"id": "node-1", "data": {"type": NodeType.LLM.value}}
|
||||
)
|
||||
mock_workflow.get_enclosing_node_type_and_id.return_value = None
|
||||
|
||||
# Mock class methods
|
||||
@ -376,7 +378,9 @@ class TestWorkflowService:
|
||||
mock_workflow.tenant_id = "tenant-1"
|
||||
mock_workflow.environment_variables = []
|
||||
mock_workflow.conversation_variables = []
|
||||
mock_workflow.get_node_config_by_id.return_value = {"id": "node-1", "data": {"type": "llm"}}
|
||||
mock_workflow.get_node_config_by_id.return_value = NodeConfigDictAdapter.validate_python(
|
||||
{"id": "node-1", "data": {"type": NodeType.LLM.value}}
|
||||
)
|
||||
mock_workflow.get_enclosing_node_type_and_id.return_value = None
|
||||
|
||||
monkeypatch.setattr(workflow_service_module, "WorkflowDraftVariableService", MagicMock())
|
||||
|
||||
Reference in New Issue
Block a user