test: added test cases for core.workflow module (#33126)

This commit is contained in:
Rajat Agarwal
2026-03-12 13:05:25 +05:30
committed by GitHub
parent 157208ab1e
commit dc50e4c4f2
6 changed files with 1266 additions and 81 deletions

View File

@ -150,8 +150,9 @@ class TestAdvancedChatAppGeneratorInternals:
"_DummyTraceQueueManager",
(TraceQueueManager,),
{
"__init__": lambda self, app_id=None, user_id=None: setattr(self, "app_id", app_id)
or setattr(self, "user_id", user_id)
"__init__": lambda self, app_id=None, user_id=None: (
setattr(self, "app_id", app_id) or setattr(self, "user_id", user_id)
)
},
)
monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.TraceQueueManager", DummyTraceQueueManager)
@ -1124,8 +1125,9 @@ class TestAdvancedChatAppGeneratorInternals:
"_DummyTraceQueueManager",
(TraceQueueManager,),
{
"__init__": lambda self, app_id=None, user_id=None: setattr(self, "app_id", app_id)
or setattr(self, "user_id", user_id)
"__init__": lambda self, app_id=None, user_id=None: (
setattr(self, "app_id", app_id) or setattr(self, "user_id", user_id)
)
},
)
monkeypatch.setattr(
@ -1202,8 +1204,9 @@ class TestAdvancedChatAppGeneratorInternals:
"_DummyTraceQueueManager",
(TraceQueueManager,),
{
"__init__": lambda self, app_id=None, user_id=None: setattr(self, "app_id", app_id)
or setattr(self, "user_id", user_id)
"__init__": lambda self, app_id=None, user_id=None: (
setattr(self, "app_id", app_id) or setattr(self, "user_id", user_id)
)
},
)
monkeypatch.setattr(

View File

@ -240,12 +240,12 @@ class TestAdvancedChatGenerateTaskPipeline:
def test_iteration_and_loop_handlers(self):
pipeline = _make_pipeline()
pipeline._workflow_run_id = "run-id"
pipeline._workflow_response_converter.workflow_iteration_start_to_stream_response = (
lambda **kwargs: "iter_start"
pipeline._workflow_response_converter.workflow_iteration_start_to_stream_response = lambda **kwargs: (
"iter_start"
)
pipeline._workflow_response_converter.workflow_iteration_next_to_stream_response = lambda **kwargs: "iter_next"
pipeline._workflow_response_converter.workflow_iteration_completed_to_stream_response = (
lambda **kwargs: "iter_done"
pipeline._workflow_response_converter.workflow_iteration_completed_to_stream_response = lambda **kwargs: (
"iter_done"
)
pipeline._workflow_response_converter.workflow_loop_start_to_stream_response = lambda **kwargs: "loop_start"
pipeline._workflow_response_converter.workflow_loop_next_to_stream_response = lambda **kwargs: "loop_next"

View File

@ -144,8 +144,9 @@ class TestWorkflowAppGeneratorGenerate:
"_DummyTraceQueueManager",
(TraceQueueManager,),
{
"__init__": lambda self, app_id=None, user_id=None: setattr(self, "app_id", app_id)
or setattr(self, "user_id", user_id)
"__init__": lambda self, app_id=None, user_id=None: (
setattr(self, "app_id", app_id) or setattr(self, "user_id", user_id)
)
},
)
monkeypatch.setattr(

View File

@ -1,82 +1,603 @@
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock, patch, sentinel
from typing import Any
import pytest
from core.model_manager import ModelInstance
from core.workflow.node_factory import DifyNodeFactory
from dify_graph.nodes.llm.entities import LLMNodeData
from dify_graph.nodes.llm.node import LLMNode
from dify_graph.runtime import GraphRuntimeState, VariablePool
from dify_graph.system_variable import SystemVariable
from tests.workflow_test_utils import build_test_graph_init_params
from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, UserFrom
from core.workflow import node_factory
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
from dify_graph.enums import NodeType, SystemVariableKey
from dify_graph.nodes.code.entities import CodeLanguage
from dify_graph.variables.segments import StringSegment
def _build_factory(graph_config: dict[str, Any]) -> DifyNodeFactory:
graph_init_params = build_test_graph_init_params(
workflow_id="workflow",
graph_config=graph_config,
tenant_id="tenant",
app_id="app",
user_id="user",
user_from="account",
invoke_from="debugger",
call_depth=0,
def _assert_typed_node_config(config, *, node_id: str, node_type: NodeType, version: str = "1") -> None:
assert config["id"] == node_id
assert isinstance(config["data"], BaseNodeData)
assert config["data"].type == node_type
assert config["data"].version == version
class TestFetchMemory:
@pytest.mark.parametrize(
("conversation_id", "memory_config"),
[
(None, object()),
("conversation-id", None),
],
)
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(
system_variables=SystemVariable.default(),
user_inputs={},
environment_variables=[],
),
start_at=0.0,
def test_returns_none_when_memory_or_conversation_is_missing(self, conversation_id, memory_config):
result = node_factory.fetch_memory(
conversation_id=conversation_id,
app_id="app-id",
node_data_memory=memory_config,
model_instance=sentinel.model_instance,
)
assert result is None
def test_returns_none_when_conversation_does_not_exist(self, monkeypatch):
class FakeSelect:
def where(self, *_args):
return self
class FakeSession:
def __init__(self, *_args, **_kwargs):
pass
def __enter__(self):
return self
def __exit__(self, *_args):
return False
def scalar(self, _stmt):
return None
monkeypatch.setattr(node_factory, "db", SimpleNamespace(engine=sentinel.engine))
monkeypatch.setattr(node_factory, "select", MagicMock(return_value=FakeSelect()))
monkeypatch.setattr(node_factory, "Session", FakeSession)
result = node_factory.fetch_memory(
conversation_id="conversation-id",
app_id="app-id",
node_data_memory=object(),
model_instance=sentinel.model_instance,
)
assert result is None
def test_builds_token_buffer_memory_for_existing_conversation(self, monkeypatch):
conversation = sentinel.conversation
memory = sentinel.memory
class FakeSelect:
def where(self, *_args):
return self
class FakeSession:
def __init__(self, *_args, **_kwargs):
pass
def __enter__(self):
return self
def __exit__(self, *_args):
return False
def scalar(self, _stmt):
return conversation
token_buffer_memory = MagicMock(return_value=memory)
monkeypatch.setattr(node_factory, "db", SimpleNamespace(engine=sentinel.engine))
monkeypatch.setattr(node_factory, "select", MagicMock(return_value=FakeSelect()))
monkeypatch.setattr(node_factory, "Session", FakeSession)
monkeypatch.setattr(node_factory, "TokenBufferMemory", token_buffer_memory)
result = node_factory.fetch_memory(
conversation_id="conversation-id",
app_id="app-id",
node_data_memory=object(),
model_instance=sentinel.model_instance,
)
assert result is memory
token_buffer_memory.assert_called_once_with(
conversation=conversation,
model_instance=sentinel.model_instance,
)
class TestDefaultWorkflowCodeExecutor:
def test_execute_delegates_to_code_executor(self, monkeypatch):
executor = node_factory.DefaultWorkflowCodeExecutor()
execute_workflow_code_template = MagicMock(return_value={"answer": "ok"})
monkeypatch.setattr(
node_factory.CodeExecutor,
"execute_workflow_code_template",
execute_workflow_code_template,
)
result = executor.execute(
language=CodeLanguage.PYTHON3,
code="print('ok')",
inputs={"name": "workflow"},
)
assert result == {"answer": "ok"}
execute_workflow_code_template.assert_called_once_with(
language=CodeLanguage.PYTHON3,
code="print('ok')",
inputs={"name": "workflow"},
)
def test_is_execution_error_checks_code_execution_error_type(self):
executor = node_factory.DefaultWorkflowCodeExecutor()
assert executor.is_execution_error(node_factory.CodeExecutionError("boom")) is True
assert executor.is_execution_error(RuntimeError("boom")) is False
class TestDifyNodeFactoryInit:
def test_init_builds_default_dependencies(self):
graph_init_params = SimpleNamespace(run_context={"context": "value"})
graph_runtime_state = sentinel.graph_runtime_state
dify_context = SimpleNamespace(tenant_id="tenant-id")
template_renderer = sentinel.template_renderer
rag_retrieval = sentinel.rag_retrieval
unstructured_api_config = sentinel.unstructured_api_config
http_request_config = sentinel.http_request_config
credentials_provider = sentinel.credentials_provider
model_factory = sentinel.model_factory
with (
patch.object(
node_factory.DifyNodeFactory,
"_resolve_dify_context",
return_value=dify_context,
) as resolve_dify_context,
patch.object(
node_factory,
"CodeExecutorJinja2TemplateRenderer",
return_value=template_renderer,
) as renderer_factory,
patch.object(node_factory, "DatasetRetrieval", return_value=rag_retrieval),
patch.object(
node_factory,
"UnstructuredApiConfig",
return_value=unstructured_api_config,
),
patch.object(
node_factory,
"build_http_request_config",
return_value=http_request_config,
),
patch.object(
node_factory,
"build_dify_model_access",
return_value=(credentials_provider, model_factory),
) as build_dify_model_access,
):
factory = node_factory.DifyNodeFactory(
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
resolve_dify_context.assert_called_once_with(graph_init_params.run_context)
build_dify_model_access.assert_called_once_with("tenant-id")
renderer_factory.assert_called_once()
assert renderer_factory.call_args.kwargs["code_executor"] is factory._code_executor
assert factory.graph_init_params is graph_init_params
assert factory.graph_runtime_state is graph_runtime_state
assert factory._dify_context is dify_context
assert factory._template_renderer is template_renderer
assert factory._rag_retrieval is rag_retrieval
assert factory._document_extractor_unstructured_api_config is unstructured_api_config
assert factory._http_request_config is http_request_config
assert factory._llm_credentials_provider is credentials_provider
assert factory._llm_model_factory is model_factory
class TestDifyNodeFactoryResolveContext:
def test_requires_reserved_context_key(self):
with pytest.raises(ValueError, match=DIFY_RUN_CONTEXT_KEY):
node_factory.DifyNodeFactory._resolve_dify_context({})
def test_returns_existing_dify_context(self):
dify_context = DifyRunContext(
tenant_id="tenant-id",
app_id="app-id",
user_id="user-id",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
)
result = node_factory.DifyNodeFactory._resolve_dify_context({DIFY_RUN_CONTEXT_KEY: dify_context})
assert result is dify_context
def test_validates_mapping_context(self):
raw_context = {
DIFY_RUN_CONTEXT_KEY: {
"tenant_id": "tenant-id",
"app_id": "app-id",
"user_id": "user-id",
"user_from": UserFrom.ACCOUNT,
"invoke_from": InvokeFrom.DEBUGGER,
}
}
result = node_factory.DifyNodeFactory._resolve_dify_context(raw_context)
assert isinstance(result, DifyRunContext)
assert result.tenant_id == "tenant-id"
class TestDifyNodeFactoryCreateNode:
@pytest.fixture
def factory(self):
factory = object.__new__(node_factory.DifyNodeFactory)
factory.graph_init_params = sentinel.graph_init_params
factory.graph_runtime_state = sentinel.graph_runtime_state
factory._dify_context = SimpleNamespace(tenant_id="tenant-id", app_id="app-id")
factory._code_executor = sentinel.code_executor
factory._code_limits = sentinel.code_limits
factory._template_renderer = sentinel.template_renderer
factory._template_transform_max_output_length = 2048
factory._http_request_http_client = sentinel.http_client
factory._http_request_tool_file_manager_factory = sentinel.tool_file_manager_factory
factory._http_request_file_manager = sentinel.file_manager
factory._rag_retrieval = sentinel.rag_retrieval
factory._document_extractor_unstructured_api_config = sentinel.unstructured_api_config
factory._http_request_config = sentinel.http_request_config
factory._llm_credentials_provider = sentinel.credentials_provider
factory._llm_model_factory = sentinel.model_factory
return factory
def test_rejects_unknown_node_type(self, factory):
with pytest.raises(ValueError, match="Input should be"):
factory.create_node({"id": "node-id", "data": {"type": "missing"}})
def test_rejects_missing_class_mapping(self, monkeypatch, factory):
monkeypatch.setattr(node_factory, "NODE_TYPE_CLASSES_MAPPING", {})
with pytest.raises(ValueError, match="No class mapping found for node type: start"):
factory.create_node({"id": "node-id", "data": {"type": NodeType.START.value}})
def test_rejects_missing_latest_class(self, monkeypatch, factory):
monkeypatch.setattr(
node_factory,
"NODE_TYPE_CLASSES_MAPPING",
{NodeType.START: {node_factory.LATEST_VERSION: None}},
)
with pytest.raises(ValueError, match="No latest version class found for node type: start"):
factory.create_node({"id": "node-id", "data": {"type": NodeType.START.value}})
def test_uses_version_specific_class_when_available(self, monkeypatch, factory):
matched_node = sentinel.matched_node
latest_node_class = MagicMock(return_value=sentinel.latest_node)
matched_node_class = MagicMock(return_value=matched_node)
monkeypatch.setattr(
node_factory,
"NODE_TYPE_CLASSES_MAPPING",
{
NodeType.START: {
node_factory.LATEST_VERSION: latest_node_class,
"9": matched_node_class,
}
},
)
result = factory.create_node({"id": "node-id", "data": {"type": NodeType.START.value, "version": "9"}})
assert result is matched_node
matched_node_class.assert_called_once()
kwargs = matched_node_class.call_args.kwargs
assert kwargs["id"] == "node-id"
_assert_typed_node_config(kwargs["config"], node_id="node-id", node_type=NodeType.START, version="9")
assert kwargs["graph_init_params"] is sentinel.graph_init_params
assert kwargs["graph_runtime_state"] is sentinel.graph_runtime_state
latest_node_class.assert_not_called()
def test_falls_back_to_latest_class_when_version_specific_mapping_is_missing(self, monkeypatch, factory):
latest_node = sentinel.latest_node
latest_node_class = MagicMock(return_value=latest_node)
monkeypatch.setattr(
node_factory,
"NODE_TYPE_CLASSES_MAPPING",
{NodeType.START: {node_factory.LATEST_VERSION: latest_node_class}},
)
result = factory.create_node({"id": "node-id", "data": {"type": NodeType.START.value, "version": "9"}})
assert result is latest_node
latest_node_class.assert_called_once()
kwargs = latest_node_class.call_args.kwargs
assert kwargs["id"] == "node-id"
_assert_typed_node_config(kwargs["config"], node_id="node-id", node_type=NodeType.START, version="9")
assert kwargs["graph_init_params"] is sentinel.graph_init_params
assert kwargs["graph_runtime_state"] is sentinel.graph_runtime_state
@pytest.mark.parametrize(
("node_type", "constructor_name"),
[
(NodeType.CODE, "CodeNode"),
(NodeType.TEMPLATE_TRANSFORM, "TemplateTransformNode"),
(NodeType.HTTP_REQUEST, "HttpRequestNode"),
(NodeType.HUMAN_INPUT, "HumanInputNode"),
(NodeType.KNOWLEDGE_INDEX, "KnowledgeIndexNode"),
(NodeType.DATASOURCE, "DatasourceNode"),
(NodeType.KNOWLEDGE_RETRIEVAL, "KnowledgeRetrievalNode"),
(NodeType.DOCUMENT_EXTRACTOR, "DocumentExtractorNode"),
],
)
return DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state)
def test_creates_specialized_nodes(self, monkeypatch, factory, node_type, constructor_name):
created_node = object()
constructor = MagicMock(name=constructor_name, return_value=created_node)
monkeypatch.setattr(
node_factory,
"NODE_TYPE_CLASSES_MAPPING",
{node_type: {node_factory.LATEST_VERSION: constructor}},
)
if constructor_name == "HumanInputNode":
form_repository = sentinel.form_repository
form_repository_impl = MagicMock(return_value=form_repository)
monkeypatch.setattr(
node_factory,
"HumanInputFormRepositoryImpl",
form_repository_impl,
)
elif constructor_name == "KnowledgeIndexNode":
index_processor = sentinel.index_processor
summary_index = sentinel.summary_index
monkeypatch.setattr(node_factory, "IndexProcessor", MagicMock(return_value=index_processor))
monkeypatch.setattr(node_factory, "SummaryIndex", MagicMock(return_value=summary_index))
node_config = {"id": "node-id", "data": {"type": node_type.value}}
result = factory.create_node(node_config)
assert result is created_node
kwargs = constructor.call_args.kwargs
assert kwargs["id"] == "node-id"
_assert_typed_node_config(kwargs["config"], node_id="node-id", node_type=node_type)
assert kwargs["graph_init_params"] is sentinel.graph_init_params
assert kwargs["graph_runtime_state"] is sentinel.graph_runtime_state
if constructor_name == "CodeNode":
assert kwargs["code_executor"] is sentinel.code_executor
assert kwargs["code_limits"] is sentinel.code_limits
elif constructor_name == "TemplateTransformNode":
assert kwargs["template_renderer"] is sentinel.template_renderer
assert kwargs["max_output_length"] == 2048
elif constructor_name == "HttpRequestNode":
assert kwargs["http_request_config"] is sentinel.http_request_config
assert kwargs["http_client"] is sentinel.http_client
assert kwargs["tool_file_manager_factory"] is sentinel.tool_file_manager_factory
assert kwargs["file_manager"] is sentinel.file_manager
elif constructor_name == "HumanInputNode":
assert kwargs["form_repository"] is form_repository
form_repository_impl.assert_called_once_with(tenant_id="tenant-id")
elif constructor_name == "KnowledgeIndexNode":
assert kwargs["index_processor"] is index_processor
assert kwargs["summary_index_service"] is summary_index
elif constructor_name == "DatasourceNode":
assert kwargs["datasource_manager"] is node_factory.DatasourceManager
elif constructor_name == "KnowledgeRetrievalNode":
assert kwargs["rag_retrieval"] is sentinel.rag_retrieval
elif constructor_name == "DocumentExtractorNode":
assert kwargs["unstructured_api_config"] is sentinel.unstructured_api_config
assert kwargs["http_client"] is sentinel.http_client
@pytest.mark.parametrize(
("node_type", "constructor_name", "expected_extra_kwargs"),
[
(NodeType.LLM, "LLMNode", {"http_client": sentinel.http_client}),
(NodeType.QUESTION_CLASSIFIER, "QuestionClassifierNode", {"http_client": sentinel.http_client}),
(NodeType.PARAMETER_EXTRACTOR, "ParameterExtractorNode", {}),
],
)
def test_creates_model_backed_nodes(
self,
monkeypatch,
factory,
node_type,
constructor_name,
expected_extra_kwargs,
):
created_node = object()
constructor = MagicMock(name=constructor_name, return_value=created_node)
monkeypatch.setattr(
node_factory,
"NODE_TYPE_CLASSES_MAPPING",
{node_type: {node_factory.LATEST_VERSION: constructor}},
)
llm_init_kwargs = {
"credentials_provider": sentinel.credentials_provider,
"model_factory": sentinel.model_factory,
"model_instance": sentinel.model_instance,
"memory": sentinel.memory,
**expected_extra_kwargs,
}
build_llm_init_kwargs = MagicMock(return_value=llm_init_kwargs)
factory._build_llm_compatible_node_init_kwargs = build_llm_init_kwargs
node_config = {"id": "node-id", "data": {"type": node_type.value}}
result = factory.create_node(node_config)
assert result is created_node
build_llm_init_kwargs.assert_called_once()
helper_kwargs = build_llm_init_kwargs.call_args.kwargs
assert helper_kwargs["node_class"] is constructor
assert isinstance(helper_kwargs["node_data"], BaseNodeData)
assert helper_kwargs["node_data"].type == node_type
assert helper_kwargs["include_http_client"] is (node_type != NodeType.PARAMETER_EXTRACTOR)
constructor_kwargs = constructor.call_args.kwargs
assert constructor_kwargs["id"] == "node-id"
_assert_typed_node_config(constructor_kwargs["config"], node_id="node-id", node_type=node_type)
assert constructor_kwargs["graph_init_params"] is sentinel.graph_init_params
assert constructor_kwargs["graph_runtime_state"] is sentinel.graph_runtime_state
assert constructor_kwargs["credentials_provider"] is sentinel.credentials_provider
assert constructor_kwargs["model_factory"] is sentinel.model_factory
assert constructor_kwargs["model_instance"] is sentinel.model_instance
assert constructor_kwargs["memory"] is sentinel.memory
for key, value in expected_extra_kwargs.items():
assert constructor_kwargs[key] is value
def test_create_node_uses_declared_node_data_type_for_llm_validation(monkeypatch):
class _FactoryLLMNodeData(LLMNodeData):
pass
class TestDifyNodeFactoryModelInstance:
@pytest.fixture
def factory(self):
factory = object.__new__(node_factory.DifyNodeFactory)
factory._llm_credentials_provider = MagicMock()
factory._llm_model_factory = MagicMock()
return factory
llm_node_config = {
"id": "llm-node",
"data": {
"type": "llm",
"title": "LLM",
"model": {
"provider": "openai",
"name": "gpt-4o-mini",
"mode": "chat",
"completion_params": {},
},
"prompt_template": [],
"context": {
"enabled": False,
},
},
}
graph_config = {"nodes": [llm_node_config], "edges": []}
factory = _build_factory(graph_config)
captured: dict[str, object] = {}
@pytest.fixture
def llm_model_setup(self, factory):
def _configure(
*,
completion_params=None,
has_provider_model=True,
model_schema=sentinel.model_schema,
):
credentials = {"api_key": "secret"}
node_data_model = SimpleNamespace(
provider="provider",
name="model",
mode="chat",
completion_params=completion_params or {},
)
node_data = SimpleNamespace(model=node_data_model)
provider_model = MagicMock() if has_provider_model else None
provider_model_bundle = SimpleNamespace(
configuration=SimpleNamespace(get_provider_model=MagicMock(return_value=provider_model))
)
model_type_instance = MagicMock()
model_type_instance.get_model_schema.return_value = model_schema
model_instance = SimpleNamespace(
provider_model_bundle=provider_model_bundle,
model_type_instance=model_type_instance,
provider=None,
model_name=None,
credentials=None,
parameters=None,
stop=None,
)
factory._llm_credentials_provider.fetch.return_value = credentials
factory._llm_model_factory.init_model_instance.return_value = model_instance
return SimpleNamespace(
node_data=node_data,
credentials=credentials,
provider_model=provider_model,
model_type_instance=model_type_instance,
model_instance=model_instance,
)
monkeypatch.setattr(LLMNode, "_node_data_type", _FactoryLLMNodeData)
return _configure
def _capture_model_instance(self: DifyNodeFactory, node_data: object) -> ModelInstance:
captured["node_data"] = node_data
return object() # type: ignore[return-value]
def test_requires_llm_mode(self, factory):
node_data = SimpleNamespace(
model=SimpleNamespace(
provider="provider",
name="model",
mode="",
completion_params={},
)
)
def _capture_memory(
self: DifyNodeFactory,
*,
node_data: object,
model_instance: ModelInstance,
) -> None:
captured["memory_node_data"] = node_data
with pytest.raises(node_factory.LLMModeRequiredError, match="LLM mode is required"):
factory._build_model_instance_for_llm_node(node_data)
monkeypatch.setattr(DifyNodeFactory, "_build_model_instance_for_llm_node", _capture_model_instance)
monkeypatch.setattr(DifyNodeFactory, "_build_memory_for_llm_node", _capture_memory)
def test_raises_when_provider_model_is_missing(self, factory, llm_model_setup):
setup = llm_model_setup(has_provider_model=False)
node = factory.create_node(llm_node_config)
with pytest.raises(node_factory.ModelNotExistError, match="Model model not exist"):
factory._build_model_instance_for_llm_node(setup.node_data)
assert isinstance(captured["node_data"], _FactoryLLMNodeData)
assert isinstance(captured["memory_node_data"], _FactoryLLMNodeData)
assert isinstance(node.node_data, _FactoryLLMNodeData)
def test_raises_when_model_schema_is_missing(self, factory, llm_model_setup):
setup = llm_model_setup(model_schema=None)
with pytest.raises(node_factory.ModelNotExistError, match="Model model not exist"):
factory._build_model_instance_for_llm_node(setup.node_data)
setup.provider_model.raise_for_status.assert_called_once()
def test_builds_model_instance_and_normalizes_stop_tokens(self, factory, llm_model_setup):
setup = llm_model_setup(
completion_params={"temperature": 0.3, "stop": "not-a-list"},
model_schema={"schema": "value"},
)
result = factory._build_model_instance_for_llm_node(setup.node_data)
assert result is setup.model_instance
assert result.provider == "provider"
assert result.model_name == "model"
assert result.credentials == setup.credentials
assert result.parameters == {"temperature": 0.3}
assert result.stop == ()
assert result.model_type_instance is setup.model_type_instance
setup.provider_model.raise_for_status.assert_called_once()
class TestDifyNodeFactoryMemory:
@pytest.fixture
def factory(self):
factory = object.__new__(node_factory.DifyNodeFactory)
factory._dify_context = SimpleNamespace(app_id="app-id")
factory.graph_runtime_state = SimpleNamespace(variable_pool=MagicMock())
return factory
def test_returns_none_when_memory_is_not_configured(self, factory):
result = factory._build_memory_for_llm_node(
node_data=SimpleNamespace(memory=None),
model_instance=sentinel.model_instance,
)
assert result is None
factory.graph_runtime_state.variable_pool.get.assert_not_called()
def test_uses_string_segment_conversation_id(self, monkeypatch, factory):
memory_config = sentinel.memory_config
factory.graph_runtime_state.variable_pool.get.return_value = StringSegment(value="conversation-id")
fetch_memory = MagicMock(return_value=sentinel.memory)
monkeypatch.setattr(node_factory, "fetch_memory", fetch_memory)
result = factory._build_memory_for_llm_node(
node_data=SimpleNamespace(memory=memory_config),
model_instance=sentinel.model_instance,
)
assert result is sentinel.memory
factory.graph_runtime_state.variable_pool.get.assert_called_once_with(
["sys", SystemVariableKey.CONVERSATION_ID]
)
fetch_memory.assert_called_once_with(
conversation_id="conversation-id",
app_id="app-id",
node_data_memory=memory_config,
model_instance=sentinel.model_instance,
)
def test_ignores_non_string_segment_conversation_ids(self, monkeypatch, factory):
memory_config = sentinel.memory_config
factory.graph_runtime_state.variable_pool.get.return_value = sentinel.segment
fetch_memory = MagicMock(return_value=sentinel.memory)
monkeypatch.setattr(node_factory, "fetch_memory", fetch_memory)
result = factory._build_memory_for_llm_node(
node_data=SimpleNamespace(memory=memory_config),
model_instance=sentinel.model_instance,
)
assert result is sentinel.memory
fetch_memory.assert_called_once_with(
conversation_id=None,
app_id="app-id",
node_data_memory=memory_config,
model_instance=sentinel.model_instance,
)

View File

@ -0,0 +1,656 @@
from collections import UserString
from types import SimpleNamespace
from unittest.mock import MagicMock, patch, sentinel
import pytest
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
from core.workflow import workflow_entry
from dify_graph.entities.graph_config import NodeConfigDictAdapter
from dify_graph.errors import WorkflowNodeRunFailedError
from dify_graph.file.enums import FileTransferMethod, FileType
from dify_graph.file.models import File
from dify_graph.graph_events import GraphRunFailedEvent
from dify_graph.nodes import NodeType
from dify_graph.runtime import ChildGraphNotFoundError
def _build_typed_node_config(node_type: NodeType):
return NodeConfigDictAdapter.validate_python({"id": "node-id", "data": {"type": node_type}})
class TestWorkflowChildEngineBuilder:
@pytest.mark.parametrize(
("graph_config", "node_id", "expected"),
[
({"nodes": [{"id": "root"}]}, "root", True),
({"nodes": [{"id": "root"}]}, "other", False),
({"nodes": "invalid"}, "root", None),
({"nodes": ["invalid"]}, "root", None),
],
)
def test_has_node_id(self, graph_config, node_id, expected):
result = workflow_entry._WorkflowChildEngineBuilder._has_node_id(graph_config, node_id)
assert result is expected
def test_build_child_engine_raises_when_root_node_is_missing(self):
builder = workflow_entry._WorkflowChildEngineBuilder()
with patch.object(workflow_entry, "DifyNodeFactory", return_value=sentinel.factory):
with pytest.raises(ChildGraphNotFoundError, match="child graph root node 'missing' not found"):
builder.build_child_engine(
workflow_id="workflow-id",
graph_init_params=sentinel.graph_init_params,
graph_runtime_state=sentinel.graph_runtime_state,
graph_config={"nodes": []},
root_node_id="missing",
)
def test_build_child_engine_constructs_graph_engine_and_layers(self):
builder = workflow_entry._WorkflowChildEngineBuilder()
child_graph = sentinel.child_graph
child_engine = MagicMock()
quota_layer = sentinel.quota_layer
additional_layers = [sentinel.layer_one, sentinel.layer_two]
with (
patch.object(workflow_entry, "DifyNodeFactory", return_value=sentinel.factory) as dify_node_factory,
patch.object(workflow_entry.Graph, "init", return_value=child_graph) as graph_init,
patch.object(workflow_entry, "GraphEngine", return_value=child_engine) as graph_engine_cls,
patch.object(workflow_entry, "GraphEngineConfig", return_value=sentinel.graph_engine_config),
patch.object(workflow_entry, "InMemoryChannel", return_value=sentinel.command_channel),
patch.object(workflow_entry, "LLMQuotaLayer", return_value=quota_layer),
):
result = builder.build_child_engine(
workflow_id="workflow-id",
graph_init_params=sentinel.graph_init_params,
graph_runtime_state=sentinel.graph_runtime_state,
graph_config={"nodes": [{"id": "root"}]},
root_node_id="root",
layers=additional_layers,
)
assert result is child_engine
dify_node_factory.assert_called_once_with(
graph_init_params=sentinel.graph_init_params,
graph_runtime_state=sentinel.graph_runtime_state,
)
graph_init.assert_called_once_with(
graph_config={"nodes": [{"id": "root"}]},
node_factory=sentinel.factory,
root_node_id="root",
)
graph_engine_cls.assert_called_once_with(
workflow_id="workflow-id",
graph=child_graph,
graph_runtime_state=sentinel.graph_runtime_state,
command_channel=sentinel.command_channel,
config=sentinel.graph_engine_config,
child_engine_builder=builder,
)
assert child_engine.layer.call_args_list == [
((quota_layer,), {}),
((sentinel.layer_one,), {}),
((sentinel.layer_two,), {}),
]
class TestWorkflowEntryInit:
def test_rejects_call_depth_above_limit(self):
call_depth = workflow_entry.dify_config.WORKFLOW_CALL_MAX_DEPTH + 1
with pytest.raises(ValueError, match="Max workflow call depth"):
workflow_entry.WorkflowEntry(
tenant_id="tenant-id",
app_id="app-id",
workflow_id="workflow-id",
graph_config={"nodes": [], "edges": []},
graph=sentinel.graph,
user_id="user-id",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=call_depth,
variable_pool=sentinel.variable_pool,
graph_runtime_state=sentinel.graph_runtime_state,
)
def test_applies_debug_and_observability_layers(self):
graph_engine = MagicMock()
debug_layer = sentinel.debug_layer
execution_limits_layer = sentinel.execution_limits_layer
llm_quota_layer = sentinel.llm_quota_layer
observability_layer = sentinel.observability_layer
with (
patch.object(workflow_entry.dify_config, "DEBUG", True),
patch.object(workflow_entry.dify_config, "ENABLE_OTEL", False),
patch.object(workflow_entry, "is_instrument_flag_enabled", return_value=True),
patch.object(workflow_entry, "GraphEngine", return_value=graph_engine) as graph_engine_cls,
patch.object(workflow_entry, "GraphEngineConfig", return_value=sentinel.graph_engine_config),
patch.object(workflow_entry, "InMemoryChannel", return_value=sentinel.command_channel),
patch.object(workflow_entry, "DebugLoggingLayer", return_value=debug_layer) as debug_logging_layer,
patch.object(
workflow_entry,
"ExecutionLimitsLayer",
return_value=execution_limits_layer,
) as execution_limits_layer_cls,
patch.object(workflow_entry, "LLMQuotaLayer", return_value=llm_quota_layer),
patch.object(workflow_entry, "ObservabilityLayer", return_value=observability_layer),
):
entry = workflow_entry.WorkflowEntry(
tenant_id="tenant-id",
app_id="app-id",
workflow_id="workflow-id-123456",
graph_config={"nodes": [], "edges": []},
graph=sentinel.graph,
user_id="user-id",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
variable_pool=sentinel.variable_pool,
graph_runtime_state=sentinel.graph_runtime_state,
command_channel=None,
)
assert entry.command_channel is sentinel.command_channel
graph_engine_cls.assert_called_once_with(
workflow_id="workflow-id-123456",
graph=sentinel.graph,
graph_runtime_state=sentinel.graph_runtime_state,
command_channel=sentinel.command_channel,
config=sentinel.graph_engine_config,
child_engine_builder=entry._child_engine_builder,
)
debug_logging_layer.assert_called_once_with(
level="DEBUG",
include_inputs=True,
include_outputs=True,
include_process_data=False,
logger_name="GraphEngine.Debug.workflow",
)
execution_limits_layer_cls.assert_called_once_with(
max_steps=workflow_entry.dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
max_time=workflow_entry.dify_config.WORKFLOW_MAX_EXECUTION_TIME,
)
assert graph_engine.layer.call_args_list == [
((debug_layer,), {}),
((execution_limits_layer,), {}),
((llm_quota_layer,), {}),
((observability_layer,), {}),
]
class TestWorkflowEntryRun:
def test_run_swallows_generate_task_stopped_errors(self):
entry = object.__new__(workflow_entry.WorkflowEntry)
entry.graph_engine = MagicMock()
entry.graph_engine.run.side_effect = GenerateTaskStoppedError()
assert list(entry.run()) == []
def test_run_emits_failed_event_for_unexpected_errors(self):
entry = object.__new__(workflow_entry.WorkflowEntry)
entry.graph_engine = MagicMock()
entry.graph_engine.run.side_effect = RuntimeError("boom")
events = list(entry.run())
assert len(events) == 1
assert isinstance(events[0], GraphRunFailedEvent)
assert events[0].error == "boom"
class TestWorkflowEntrySingleStepRun:
def test_uses_empty_mapping_when_selector_extraction_is_not_implemented(self):
class FakeNode:
id = "node-id"
title = "Node Title"
node_type = "fake"
@staticmethod
def version():
return "1"
@staticmethod
def extract_variable_selector_to_variable_mapping(**_kwargs):
raise NotImplementedError
with (
patch.object(workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params),
patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state),
patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}),
patch.object(workflow_entry.time, "perf_counter", return_value=123.0),
patch.object(workflow_entry, "DifyNodeFactory") as dify_node_factory,
patch.object(workflow_entry, "load_into_variable_pool") as load_into_variable_pool,
patch.object(
workflow_entry.WorkflowEntry,
"mapping_user_inputs_to_variable_pool",
) as mapping_user_inputs_to_variable_pool,
patch.object(
workflow_entry.WorkflowEntry,
"_traced_node_run",
return_value=iter(["event"]),
),
):
dify_node_factory.return_value.create_node.return_value = FakeNode()
workflow = SimpleNamespace(
tenant_id="tenant-id",
app_id="app-id",
id="workflow-id",
graph_dict={"nodes": [], "edges": []},
get_node_config_by_id=lambda _node_id: _build_typed_node_config(NodeType.START),
)
node, generator = workflow_entry.WorkflowEntry.single_step_run(
workflow=workflow,
node_id="node-id",
user_id="user-id",
user_inputs={"question": "hello"},
variable_pool=sentinel.variable_pool,
)
assert node.id == "node-id"
assert list(generator) == ["event"]
load_into_variable_pool.assert_called_once_with(
variable_loader=workflow_entry.DUMMY_VARIABLE_LOADER,
variable_pool=sentinel.variable_pool,
variable_mapping={},
user_inputs={"question": "hello"},
)
mapping_user_inputs_to_variable_pool.assert_called_once_with(
variable_mapping={},
user_inputs={"question": "hello"},
variable_pool=sentinel.variable_pool,
tenant_id="tenant-id",
)
def test_skips_user_input_mapping_for_datasource_nodes(self):
class FakeDatasourceNode:
id = "node-id"
node_type = "datasource"
@staticmethod
def version():
return "1"
@staticmethod
def extract_variable_selector_to_variable_mapping(**_kwargs):
return {"question": ["node", "question"]}
with (
patch.object(workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params),
patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state),
patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}),
patch.object(workflow_entry.time, "perf_counter", return_value=123.0),
patch.object(workflow_entry, "DifyNodeFactory") as dify_node_factory,
patch.object(workflow_entry, "load_into_variable_pool") as load_into_variable_pool,
patch.object(
workflow_entry.WorkflowEntry,
"mapping_user_inputs_to_variable_pool",
) as mapping_user_inputs_to_variable_pool,
patch.object(
workflow_entry.WorkflowEntry,
"_traced_node_run",
return_value=iter(["event"]),
),
):
dify_node_factory.return_value.create_node.return_value = FakeDatasourceNode()
workflow = SimpleNamespace(
tenant_id="tenant-id",
app_id="app-id",
id="workflow-id",
graph_dict={"nodes": [], "edges": []},
get_node_config_by_id=lambda _node_id: _build_typed_node_config(NodeType.DATASOURCE),
)
node, generator = workflow_entry.WorkflowEntry.single_step_run(
workflow=workflow,
node_id="node-id",
user_id="user-id",
user_inputs={"question": "hello"},
variable_pool=sentinel.variable_pool,
)
assert node.id == "node-id"
assert list(generator) == ["event"]
load_into_variable_pool.assert_called_once()
mapping_user_inputs_to_variable_pool.assert_not_called()
def test_wraps_traced_node_run_failures(self):
class FakeNode:
id = "node-id"
title = "Node Title"
node_type = "fake"
@staticmethod
def extract_variable_selector_to_variable_mapping(**_kwargs):
return {}
@staticmethod
def version():
return "1"
with (
patch.object(workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params),
patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state),
patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}),
patch.object(workflow_entry.time, "perf_counter", return_value=123.0),
patch.object(workflow_entry, "DifyNodeFactory") as dify_node_factory,
patch.object(workflow_entry, "load_into_variable_pool"),
patch.object(workflow_entry.WorkflowEntry, "mapping_user_inputs_to_variable_pool"),
patch.object(
workflow_entry.WorkflowEntry,
"_traced_node_run",
side_effect=RuntimeError("boom"),
),
):
dify_node_factory.return_value.create_node.return_value = FakeNode()
workflow = SimpleNamespace(
tenant_id="tenant-id",
app_id="app-id",
id="workflow-id",
graph_dict={"nodes": [], "edges": []},
get_node_config_by_id=lambda _node_id: _build_typed_node_config(NodeType.START),
)
with pytest.raises(WorkflowNodeRunFailedError):
workflow_entry.WorkflowEntry.single_step_run(
workflow=workflow,
node_id="node-id",
user_id="user-id",
user_inputs={},
variable_pool=sentinel.variable_pool,
)
class TestWorkflowEntryHelpers:
def test_create_single_node_graph_builds_start_edge(self):
graph = workflow_entry.WorkflowEntry._create_single_node_graph(
node_id="target-node",
node_data={"type": NodeType.PARAMETER_EXTRACTOR},
node_width=320,
node_height=180,
)
assert graph["nodes"][0]["id"] == "start"
assert graph["nodes"][1]["id"] == "target-node"
assert graph["nodes"][1]["width"] == 320
assert graph["nodes"][1]["height"] == 180
assert graph["edges"] == [
{
"source": "start",
"target": "target-node",
"sourceHandle": "source",
"targetHandle": "target",
}
]
def test_run_free_node_rejects_unsupported_types(self):
with pytest.raises(ValueError, match="Node type start not supported"):
workflow_entry.WorkflowEntry.run_free_node(
node_data={"type": NodeType.START.value},
node_id="node-id",
tenant_id="tenant-id",
user_id="user-id",
user_inputs={},
)
def test_run_free_node_rejects_missing_node_class(self, monkeypatch):
monkeypatch.setattr(
workflow_entry,
"NODE_TYPE_CLASSES_MAPPING",
{NodeType.PARAMETER_EXTRACTOR: {"1": None}},
)
with pytest.raises(ValueError, match="Node class not found for node type parameter-extractor"):
workflow_entry.WorkflowEntry.run_free_node(
node_data={"type": NodeType.PARAMETER_EXTRACTOR.value},
node_id="node-id",
tenant_id="tenant-id",
user_id="user-id",
user_inputs={},
)
def test_run_free_node_uses_empty_mapping_when_selector_extraction_is_not_implemented(self, monkeypatch):
class FakeNodeClass:
@staticmethod
def extract_variable_selector_to_variable_mapping(**_kwargs):
raise NotImplementedError
class FakeNode:
id = "node-id"
title = "Node Title"
node_type = "parameter-extractor"
@staticmethod
def version():
return "1"
dify_node_factory = MagicMock()
dify_node_factory.create_node.return_value = FakeNode()
monkeypatch.setattr(
workflow_entry,
"NODE_TYPE_CLASSES_MAPPING",
{NodeType.PARAMETER_EXTRACTOR: {"1": FakeNodeClass}},
)
with (
patch.object(workflow_entry.SystemVariable, "default", return_value=sentinel.system_variables),
patch.object(workflow_entry, "VariablePool", return_value=sentinel.variable_pool) as variable_pool_cls,
patch.object(
workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params
) as graph_init_params,
patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state),
patch.object(
workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}
) as build_dify_run_context,
patch.object(workflow_entry.time, "perf_counter", return_value=123.0),
patch.object(workflow_entry, "DifyNodeFactory", return_value=dify_node_factory) as dify_node_factory_cls,
patch.object(
workflow_entry.WorkflowEntry,
"mapping_user_inputs_to_variable_pool",
) as mapping_user_inputs_to_variable_pool,
patch.object(
workflow_entry.WorkflowEntry,
"_traced_node_run",
return_value=iter(["event"]),
),
):
node, generator = workflow_entry.WorkflowEntry.run_free_node(
node_data={"type": NodeType.PARAMETER_EXTRACTOR.value, "title": "Node"},
node_id="node-id",
tenant_id="tenant-id",
user_id="user-id",
user_inputs={"question": "hello"},
)
assert node.id == "node-id"
assert list(generator) == ["event"]
variable_pool_cls.assert_called_once_with(
system_variables=sentinel.system_variables,
user_inputs={},
environment_variables=[],
)
build_dify_run_context.assert_called_once_with(
tenant_id="tenant-id",
app_id="",
user_id="user-id",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
)
graph_init_params.assert_called_once_with(
workflow_id="",
graph_config=workflow_entry.WorkflowEntry._create_single_node_graph(
"node-id", {"type": NodeType.PARAMETER_EXTRACTOR.value, "title": "Node"}
),
run_context={"_dify": "context"},
call_depth=0,
)
dify_node_factory_cls.assert_called_once_with(
graph_init_params=sentinel.graph_init_params,
graph_runtime_state=sentinel.graph_runtime_state,
)
mapping_user_inputs_to_variable_pool.assert_called_once_with(
variable_mapping={},
user_inputs={"question": "hello"},
variable_pool=sentinel.variable_pool,
tenant_id="tenant-id",
)
def test_run_free_node_wraps_execution_failures(self, monkeypatch):
class FakeNodeClass:
@staticmethod
def extract_variable_selector_to_variable_mapping(**_kwargs):
return {}
class FakeNode:
id = "node-id"
title = "Node Title"
node_type = "parameter-extractor"
@staticmethod
def version():
return "1"
dify_node_factory = MagicMock()
dify_node_factory.create_node.return_value = FakeNode()
monkeypatch.setattr(
workflow_entry,
"NODE_TYPE_CLASSES_MAPPING",
{NodeType.PARAMETER_EXTRACTOR: {"1": FakeNodeClass}},
)
with (
patch.object(workflow_entry.SystemVariable, "default", return_value=sentinel.system_variables),
patch.object(workflow_entry, "VariablePool", return_value=sentinel.variable_pool),
patch.object(workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params),
patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state),
patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}),
patch.object(workflow_entry.time, "perf_counter", return_value=123.0),
patch.object(workflow_entry, "DifyNodeFactory", return_value=dify_node_factory),
patch.object(
workflow_entry.WorkflowEntry,
"mapping_user_inputs_to_variable_pool",
side_effect=RuntimeError("boom"),
),
):
with pytest.raises(WorkflowNodeRunFailedError, match="Node Title run failed: boom"):
workflow_entry.WorkflowEntry.run_free_node(
node_data={"type": NodeType.PARAMETER_EXTRACTOR.value, "title": "Node"},
node_id="node-id",
tenant_id="tenant-id",
user_id="user-id",
user_inputs={"question": "hello"},
)
def test_handle_special_values_serializes_nested_files(self):
file = File(
tenant_id="tenant-id",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/image.png",
filename="image.png",
extension=".png",
)
result = workflow_entry.WorkflowEntry.handle_special_values({"file": file, "nested": {"files": [file]}})
assert result == {
"file": file.to_dict(),
"nested": {"files": [file.to_dict()]},
}
def test_handle_special_values_returns_none_for_none(self):
assert workflow_entry.WorkflowEntry._handle_special_values(None) is None
def test_handle_special_values_returns_scalar_as_is(self):
assert workflow_entry.WorkflowEntry._handle_special_values("plain-text") == "plain-text"
class TestMappingUserInputsBranches:
def test_rejects_invalid_node_variable_key(self):
class EmptySplitKey(UserString):
def split(self, _sep=None):
return []
with pytest.raises(ValueError, match="Invalid node variable broken"):
workflow_entry.WorkflowEntry.mapping_user_inputs_to_variable_pool(
variable_mapping={EmptySplitKey("broken"): ["node", "input"]},
user_inputs={},
variable_pool=MagicMock(),
tenant_id="tenant-id",
)
def test_skips_none_user_input_when_variable_already_exists(self):
variable_pool = MagicMock()
variable_pool.get.return_value = None
workflow_entry.WorkflowEntry.mapping_user_inputs_to_variable_pool(
variable_mapping={"node.input": ["target", "input"]},
user_inputs={"node.input": None},
variable_pool=variable_pool,
tenant_id="tenant-id",
)
variable_pool.add.assert_not_called()
def test_merges_structured_output_values(self):
variable_pool = MagicMock()
variable_pool.get.side_effect = [
None,
SimpleNamespace(value={"existing": "value"}),
]
workflow_entry.WorkflowEntry.mapping_user_inputs_to_variable_pool(
variable_mapping={"node.answer": ["target", "structured_output", "answer"]},
user_inputs={"node.answer": "new-value"},
variable_pool=variable_pool,
tenant_id="tenant-id",
)
variable_pool.add.assert_called_once_with(
["target", "structured_output"],
{"existing": "value", "answer": "new-value"},
)
class TestWorkflowEntryTracing:
def test_traced_node_run_reports_success(self):
layer = MagicMock()
class FakeNode:
def ensure_execution_id(self):
return None
def run(self):
yield "event"
with patch.object(workflow_entry, "ObservabilityLayer", return_value=layer):
events = list(workflow_entry.WorkflowEntry._traced_node_run(FakeNode()))
assert events == ["event"]
layer.on_graph_start.assert_called_once_with()
layer.on_node_run_start.assert_called_once()
layer.on_node_run_end.assert_called_once_with(
layer.on_node_run_start.call_args.args[0],
None,
)
def test_traced_node_run_reports_errors(self):
layer = MagicMock()
class FakeNode:
def ensure_execution_id(self):
return None
def run(self):
raise RuntimeError("boom")
yield
with patch.object(workflow_entry, "ObservabilityLayer", return_value=layer):
with pytest.raises(RuntimeError, match="boom"):
list(workflow_entry.WorkflowEntry._traced_node_run(FakeNode()))
assert isinstance(layer.on_node_run_end.call_args.args[1], RuntimeError)

View File

@ -311,7 +311,9 @@ class TestWorkflowService:
mock_workflow.conversation_variables = []
# Mock node config
mock_workflow.get_node_config_by_id.return_value = {"id": "node-1", "data": {"type": "llm"}}
mock_workflow.get_node_config_by_id.return_value = NodeConfigDictAdapter.validate_python(
{"id": "node-1", "data": {"type": NodeType.LLM.value}}
)
mock_workflow.get_enclosing_node_type_and_id.return_value = None
# Mock class methods
@ -376,7 +378,9 @@ class TestWorkflowService:
mock_workflow.tenant_id = "tenant-1"
mock_workflow.environment_variables = []
mock_workflow.conversation_variables = []
mock_workflow.get_node_config_by_id.return_value = {"id": "node-1", "data": {"type": "llm"}}
mock_workflow.get_node_config_by_id.return_value = NodeConfigDictAdapter.validate_python(
{"id": "node-1", "data": {"type": NodeType.LLM.value}}
)
mock_workflow.get_enclosing_node_type_and_id.return_value = None
monkeypatch.setattr(workflow_service_module, "WorkflowDraftVariableService", MagicMock())