test(api): cover remaining workflow typing branches

Yanli 盐粒
2026-03-25 19:46:47 +08:00
parent e819a9a5f7
commit b7a5ed6c0b
5 changed files with 286 additions and 1 deletion


@@ -0,0 +1,110 @@
from collections.abc import Iterator
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.task_entities import AppBlockingResponse
from core.errors.error import QuotaExceededError
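
# Concrete stand-in for the abstract converter: each hook tags its output with a
# distinct "mode" so the tests can tell which branch convert() dispatched to.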
class DummyResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = AppBlockingResponse
@classmethod
def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict[str, str]:
return {"mode": "blocking-full", "task_id": blocking_response.task_id}
@classmethod
def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict[str, str]:
return {"mode": "blocking-simple", "task_id": blocking_response.task_id}
@classmethod
def convert_stream_full_response(cls, stream_response: Iterator[object]):
for _ in stream_response:
yield {"mode": "stream-full"}
@classmethod
def convert_stream_simple_response(cls, stream_response: Iterator[object]):
for _ in stream_response:
yield {"mode": "stream-simple"}
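
# DEBUGGER and SERVICE_API callers should get the full payloads; WEB_APP gets the simple ones.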
def test_convert_routes_to_full_or_simple_modes() -> None:
blocking = AppBlockingResponse(task_id="task-1")
assert DummyResponseConverter.convert(blocking, InvokeFrom.DEBUGGER) == {
"mode": "blocking-full",
"task_id": "task-1",
}
assert DummyResponseConverter.convert(blocking, InvokeFrom.WEB_APP) == {
"mode": "blocking-simple",
"task_id": "task-1",
}
assert list(DummyResponseConverter.convert(iter([object()]), InvokeFrom.SERVICE_API)) == [{"mode": "stream-full"}]
assert list(DummyResponseConverter.convert(iter([object()]), InvokeFrom.WEB_APP)) == [{"mode": "stream-simple"}]
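
# _get_simple_metadata should strip annotation_reply and usage while passing
# retriever_resources through untouched, including the newer fields asserted below.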
def test_get_simple_metadata_preserves_new_retriever_fields() -> None:
metadata = {
"retriever_resources": [
{
"dataset_id": "dataset-1",
"dataset_name": "Dataset",
"document_id": "document-1",
"segment_id": "segment-1",
"position": 1,
"data_source_type": "upload_file",
"document_name": "Document",
"score": 0.9,
"hit_count": 2,
"word_count": 128,
"segment_position": 3,
"index_node_hash": "hash",
"content": "content",
"page": 5,
"title": "Title",
"files": [{"id": "file-1"}],
"summary": "summary",
}
],
"annotation_reply": "hidden",
"usage": {"latency": 0.1},
}
result = DummyResponseConverter._get_simple_metadata(metadata)
assert result == {
"retriever_resources": [
{
"dataset_id": "dataset-1",
"dataset_name": "Dataset",
"document_id": "document-1",
"segment_id": "segment-1",
"position": 1,
"data_source_type": "upload_file",
"document_name": "Document",
"score": 0.9,
"hit_count": 2,
"word_count": 128,
"segment_position": 3,
"index_node_hash": "hash",
"content": "content",
"page": 5,
"title": "Title",
"files": [{"id": "file-1"}],
"summary": "summary",
}
]
}
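
# Known errors map to their registered code/status; anything unrecognized falls
# back to the generic internal_server_error payload.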
def test_error_to_stream_response_uses_specific_and_fallback_mappings() -> None:
quota_response = DummyResponseConverter._error_to_stream_response(QuotaExceededError())
fallback_response = DummyResponseConverter._error_to_stream_response(RuntimeError("boom"))
assert quota_response["code"] == "provider_quota_exceeded"
assert quota_response["status"] == 400
assert fallback_response == {
"code": "internal_server_error",
"message": "Internal Server Error, please contact support.",
"status": 500,
}


@@ -0,0 +1,46 @@
import pytest
from pydantic import ValidationError
from core.workflow.nodes.agent.entities import AgentNodeData
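
# AgentNodeData coerces each agent_parameters entry into a typed input: "variable"
# entries carry a selector list, while "mixed" entries keep their raw value.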
def test_agent_input_accepts_variable_selector_and_mixed_values() -> None:
node_data = AgentNodeData.model_validate(
{
"title": "Agent",
"agent_strategy_provider_name": "provider",
"agent_strategy_name": "strategy",
"agent_strategy_label": "Strategy",
"agent_parameters": {
"query": {"type": "variable", "value": ["start", "query"]},
"tools": {"type": "mixed", "value": [{"provider": "builtin", "name": "search"}]},
},
}
)
assert node_data.agent_parameters["query"].value == ["start", "query"]
assert node_data.agent_parameters["tools"].value == [{"provider": "builtin", "name": "search"}]
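
# A "variable" input whose value is not a selector list fails validation, and an
# unrecognized type string surfaces the "Unknown agent input type" error.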
def test_agent_input_rejects_invalid_variable_selector_and_unknown_type() -> None:
with pytest.raises(ValidationError):
AgentNodeData.model_validate(
{
"title": "Agent",
"agent_strategy_provider_name": "provider",
"agent_strategy_name": "strategy",
"agent_strategy_label": "Strategy",
"agent_parameters": {"query": {"type": "variable", "value": "start.query"}},
}
)
with pytest.raises(ValidationError, match="Unknown agent input type"):
AgentNodeData.model_validate(
{
"title": "Agent",
"agent_strategy_provider_name": "provider",
"agent_strategy_name": "strategy",
"agent_strategy_label": "Strategy",
"agent_parameters": {"query": {"type": "unsupported", "value": "hello"}},
}
)


@@ -15,6 +15,7 @@ from dify_graph.entities import GraphInitParams
from dify_graph.file import File, FileTransferMethod, FileType
from dify_graph.model_runtime.entities import LLMMode
from dify_graph.model_runtime.entities.common_entities import I18nObject
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultWithStructuredOutput, LLMUsage
from dify_graph.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
@@ -120,6 +121,54 @@ def test_prompt_config_converts_none_jinja_variables() -> None:
assert prompt_config.jinja2_variables == []
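
# The value passed under "schema" is returned as-is when it is a dict; a missing
# or non-dict schema raises with a "valid structured output schema" message.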
def test_fetch_structured_output_schema_validates_required_object_shape() -> None:
assert LLMNode.fetch_structured_output_schema(structured_output={"schema": {"type": "object", "a": 1}}) == {
"type": "object",
"a": 1,
}
with pytest.raises(Exception, match="valid structured output schema"):
LLMNode.fetch_structured_output_schema(structured_output={"schema": None})
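
# With reasoning_format="separated", the <think>...</think> block is split out into
# reasoning_content, the remaining text becomes the answer, and the request latency
# is normalized (1.2345 -> 1.234) on the returned usage.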
def test_handle_blocking_result_separates_reasoning_and_structured_output() -> None:
saver = mock.MagicMock(spec=LLMFileSaver)
event = LLMNode.handle_blocking_result(
invoke_result=LLMResultWithStructuredOutput(
model="gpt",
message=AssistantPromptMessage(content="<think>reasoning</think>answer"),
usage=LLMUsage.empty_usage(),
structured_output={"answer": "done"},
),
saver=saver,
file_outputs=[],
reasoning_format="separated",
request_latency=1.2345,
)
assert event.text == "answer"
assert event.reasoning_content == "reasoning"
assert event.structured_output == {"answer": "done"}
assert event.usage.latency == 1.234
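
# A plain LLMResult with untagged text passes through unchanged: no reasoning is
# extracted and structured_output stays None.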
def test_handle_blocking_result_keeps_tagged_text_without_structured_output() -> None:
saver = mock.MagicMock(spec=LLMFileSaver)
event = LLMNode.handle_blocking_result(
invoke_result=LLMResult(
model="gpt",
message=AssistantPromptMessage(content="plain text"),
usage=LLMUsage.empty_usage(),
),
saver=saver,
file_outputs=[],
)
assert event.text == "plain text"
assert event.reasoning_content == ""
assert event.structured_output is None
@pytest.fixture
def llm_node_data() -> LLMNodeData:
return LLMNodeData(


@@ -8,7 +8,7 @@ from unittest.mock import MagicMock, patch
import pytest
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from dify_graph.file import File, FileTransferMethod, FileType
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
@@ -221,3 +221,67 @@ def test_tool_node_data_filters_missing_tool_parameter_values() -> None:
)
assert set(node_data.tool_parameters.keys()) == {"query"}
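
# _generate_parameters resolves "variable" entries through the variable pool; an
# optional parameter whose selector resolves to nothing is simply omitted.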
def test_generate_parameters_reads_variables_and_optional_missing_inputs(tool_node: ToolNode) -> None:
variable_pool = MagicMock()
variable_pool.get.side_effect = [MagicMock(value="from-variable"), None]
node_data = ToolNodeData.model_validate(
{
"title": "Tool",
"provider_id": "provider",
"provider_type": "builtin",
"provider_name": "provider",
"tool_name": "tool",
"tool_label": "tool",
"tool_configurations": {},
"tool_parameters": {
"query": {"type": "variable", "value": ["start", "query"]},
"optional": {"type": "variable", "value": ["start", "optional"]},
},
}
)
tool_parameters = [
ToolParameter.get_simple_instance("query", "query", ToolParameter.ToolParameterType.STRING, True),
ToolParameter.get_simple_instance("optional", "optional", ToolParameter.ToolParameterType.STRING, False),
]
result = tool_node._generate_parameters(
tool_parameters=tool_parameters,
variable_pool=variable_pool,
node_data=node_data,
)
assert result == {"query": "from-variable"}
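
# With for_log=True, "mixed" templates render via convert_template but report the
# masked log text, and parameters the tool does not declare come back as None.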
def test_generate_parameters_formats_logs_and_unknown_parameters(tool_node: ToolNode) -> None:
variable_pool = MagicMock()
variable_pool.convert_template.return_value = MagicMock(text="rendered", log="masked")
node_data = ToolNodeData.model_validate(
{
"title": "Tool",
"provider_id": "provider",
"provider_type": "builtin",
"provider_name": "provider",
"tool_name": "tool",
"tool_label": "tool",
"tool_configurations": {},
"tool_parameters": {
"query": {"type": "mixed", "value": "{{ question }}"},
"missing": {"type": "constant", "value": "literal"},
},
}
)
tool_parameters = [
ToolParameter.get_simple_instance("query", "query", ToolParameter.ToolParameterType.STRING, True),
]
result = tool_node._generate_parameters(
tool_parameters=tool_parameters,
variable_pool=variable_pool,
node_data=node_data,
for_log=True,
)
assert result == {"query": "masked", "missing": None}


@@ -97,6 +97,22 @@ class TestWorkflowChildEngineBuilder:
((sentinel.layer_two,), {}),
]
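
    # Graph validation is deferred to Graph.init: a malformed graph_config only
    # fails once init runs, and the ValueError propagates out of build_child_engine.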
def test_build_child_engine_tolerates_invalid_graph_shape_until_graph_init(self):
builder = workflow_entry._WorkflowChildEngineBuilder()
with (
patch.object(workflow_entry, "DifyNodeFactory", return_value=sentinel.factory),
patch.object(workflow_entry.Graph, "init", side_effect=ValueError("invalid graph")),
):
with pytest.raises(ValueError, match="invalid graph"):
builder.build_child_engine(
workflow_id="workflow-id",
graph_init_params=sentinel.graph_init_params,
graph_runtime_state=sentinel.graph_runtime_state,
graph_config={"nodes": "invalid"},
root_node_id="root",
)
class TestWorkflowEntryInit:
def test_rejects_call_depth_above_limit(self):