mirror of
https://github.com/langgenius/dify.git
synced 2026-05-01 16:08:04 +08:00
test(api): cover remaining workflow typing branches
This commit is contained in:
@ -0,0 +1,110 @@
|
||||
from collections.abc import Iterator
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.task_entities import AppBlockingResponse
|
||||
from core.errors.error import QuotaExceededError
|
||||
|
||||
|
||||
class DummyResponseConverter(AppGenerateResponseConverter):
|
||||
_blocking_response_type = AppBlockingResponse
|
||||
|
||||
@classmethod
|
||||
def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict[str, str]:
|
||||
return {"mode": "blocking-full", "task_id": blocking_response.task_id}
|
||||
|
||||
@classmethod
|
||||
def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict[str, str]:
|
||||
return {"mode": "blocking-simple", "task_id": blocking_response.task_id}
|
||||
|
||||
@classmethod
|
||||
def convert_stream_full_response(cls, stream_response: Iterator[object]):
|
||||
for _ in stream_response:
|
||||
yield {"mode": "stream-full"}
|
||||
|
||||
@classmethod
|
||||
def convert_stream_simple_response(cls, stream_response: Iterator[object]):
|
||||
for _ in stream_response:
|
||||
yield {"mode": "stream-simple"}
|
||||
|
||||
|
||||
def test_convert_routes_to_full_or_simple_modes() -> None:
|
||||
blocking = AppBlockingResponse(task_id="task-1")
|
||||
|
||||
assert DummyResponseConverter.convert(blocking, InvokeFrom.DEBUGGER) == {
|
||||
"mode": "blocking-full",
|
||||
"task_id": "task-1",
|
||||
}
|
||||
assert DummyResponseConverter.convert(blocking, InvokeFrom.WEB_APP) == {
|
||||
"mode": "blocking-simple",
|
||||
"task_id": "task-1",
|
||||
}
|
||||
assert list(DummyResponseConverter.convert(iter([object()]), InvokeFrom.SERVICE_API)) == [{"mode": "stream-full"}]
|
||||
assert list(DummyResponseConverter.convert(iter([object()]), InvokeFrom.WEB_APP)) == [{"mode": "stream-simple"}]
|
||||
|
||||
|
||||
def test_get_simple_metadata_preserves_new_retriever_fields() -> None:
|
||||
metadata = {
|
||||
"retriever_resources": [
|
||||
{
|
||||
"dataset_id": "dataset-1",
|
||||
"dataset_name": "Dataset",
|
||||
"document_id": "document-1",
|
||||
"segment_id": "segment-1",
|
||||
"position": 1,
|
||||
"data_source_type": "upload_file",
|
||||
"document_name": "Document",
|
||||
"score": 0.9,
|
||||
"hit_count": 2,
|
||||
"word_count": 128,
|
||||
"segment_position": 3,
|
||||
"index_node_hash": "hash",
|
||||
"content": "content",
|
||||
"page": 5,
|
||||
"title": "Title",
|
||||
"files": [{"id": "file-1"}],
|
||||
"summary": "summary",
|
||||
}
|
||||
],
|
||||
"annotation_reply": "hidden",
|
||||
"usage": {"latency": 0.1},
|
||||
}
|
||||
|
||||
result = DummyResponseConverter._get_simple_metadata(metadata)
|
||||
|
||||
assert result == {
|
||||
"retriever_resources": [
|
||||
{
|
||||
"dataset_id": "dataset-1",
|
||||
"dataset_name": "Dataset",
|
||||
"document_id": "document-1",
|
||||
"segment_id": "segment-1",
|
||||
"position": 1,
|
||||
"data_source_type": "upload_file",
|
||||
"document_name": "Document",
|
||||
"score": 0.9,
|
||||
"hit_count": 2,
|
||||
"word_count": 128,
|
||||
"segment_position": 3,
|
||||
"index_node_hash": "hash",
|
||||
"content": "content",
|
||||
"page": 5,
|
||||
"title": "Title",
|
||||
"files": [{"id": "file-1"}],
|
||||
"summary": "summary",
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def test_error_to_stream_response_uses_specific_and_fallback_mappings() -> None:
|
||||
quota_response = DummyResponseConverter._error_to_stream_response(QuotaExceededError())
|
||||
fallback_response = DummyResponseConverter._error_to_stream_response(RuntimeError("boom"))
|
||||
|
||||
assert quota_response["code"] == "provider_quota_exceeded"
|
||||
assert quota_response["status"] == 400
|
||||
assert fallback_response == {
|
||||
"code": "internal_server_error",
|
||||
"message": "Internal Server Error, please contact support.",
|
||||
"status": 500,
|
||||
}
|
||||
@ -0,0 +1,46 @@
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.workflow.nodes.agent.entities import AgentNodeData
|
||||
|
||||
|
||||
def test_agent_input_accepts_variable_selector_and_mixed_values() -> None:
|
||||
node_data = AgentNodeData.model_validate(
|
||||
{
|
||||
"title": "Agent",
|
||||
"agent_strategy_provider_name": "provider",
|
||||
"agent_strategy_name": "strategy",
|
||||
"agent_strategy_label": "Strategy",
|
||||
"agent_parameters": {
|
||||
"query": {"type": "variable", "value": ["start", "query"]},
|
||||
"tools": {"type": "mixed", "value": [{"provider": "builtin", "name": "search"}]},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
assert node_data.agent_parameters["query"].value == ["start", "query"]
|
||||
assert node_data.agent_parameters["tools"].value == [{"provider": "builtin", "name": "search"}]
|
||||
|
||||
|
||||
def test_agent_input_rejects_invalid_variable_selector_and_unknown_type() -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
AgentNodeData.model_validate(
|
||||
{
|
||||
"title": "Agent",
|
||||
"agent_strategy_provider_name": "provider",
|
||||
"agent_strategy_name": "strategy",
|
||||
"agent_strategy_label": "Strategy",
|
||||
"agent_parameters": {"query": {"type": "variable", "value": "start.query"}},
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(ValidationError, match="Unknown agent input type"):
|
||||
AgentNodeData.model_validate(
|
||||
{
|
||||
"title": "Agent",
|
||||
"agent_strategy_provider_name": "provider",
|
||||
"agent_strategy_name": "strategy",
|
||||
"agent_strategy_label": "Strategy",
|
||||
"agent_parameters": {"query": {"type": "unsupported", "value": "hello"}},
|
||||
}
|
||||
)
|
||||
@ -15,6 +15,7 @@ from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.file import File, FileTransferMethod, FileType
|
||||
from dify_graph.model_runtime.entities import LLMMode
|
||||
from dify_graph.model_runtime.entities.common_entities import I18nObject
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultWithStructuredOutput, LLMUsage
|
||||
from dify_graph.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
ImagePromptMessageContent,
|
||||
@ -120,6 +121,54 @@ def test_prompt_config_converts_none_jinja_variables() -> None:
|
||||
assert prompt_config.jinja2_variables == []
|
||||
|
||||
|
||||
def test_fetch_structured_output_schema_validates_required_object_shape() -> None:
|
||||
assert LLMNode.fetch_structured_output_schema(structured_output={"schema": {"type": "object", "a": 1}}) == {
|
||||
"type": "object",
|
||||
"a": 1,
|
||||
}
|
||||
|
||||
with pytest.raises(Exception, match="valid structured output schema"):
|
||||
LLMNode.fetch_structured_output_schema(structured_output={"schema": None})
|
||||
|
||||
|
||||
def test_handle_blocking_result_separates_reasoning_and_structured_output() -> None:
|
||||
saver = mock.MagicMock(spec=LLMFileSaver)
|
||||
event = LLMNode.handle_blocking_result(
|
||||
invoke_result=LLMResultWithStructuredOutput(
|
||||
model="gpt",
|
||||
message=AssistantPromptMessage(content="<think>reasoning</think>answer"),
|
||||
usage=LLMUsage.empty_usage(),
|
||||
structured_output={"answer": "done"},
|
||||
),
|
||||
saver=saver,
|
||||
file_outputs=[],
|
||||
reasoning_format="separated",
|
||||
request_latency=1.2345,
|
||||
)
|
||||
|
||||
assert event.text == "answer"
|
||||
assert event.reasoning_content == "reasoning"
|
||||
assert event.structured_output == {"answer": "done"}
|
||||
assert event.usage.latency == 1.234
|
||||
|
||||
|
||||
def test_handle_blocking_result_keeps_tagged_text_without_structured_output() -> None:
|
||||
saver = mock.MagicMock(spec=LLMFileSaver)
|
||||
event = LLMNode.handle_blocking_result(
|
||||
invoke_result=LLMResult(
|
||||
model="gpt",
|
||||
message=AssistantPromptMessage(content="plain text"),
|
||||
usage=LLMUsage.empty_usage(),
|
||||
),
|
||||
saver=saver,
|
||||
file_outputs=[],
|
||||
)
|
||||
|
||||
assert event.text == "plain text"
|
||||
assert event.reasoning_content == ""
|
||||
assert event.structured_output is None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def llm_node_data() -> LLMNodeData:
|
||||
return LLMNodeData(
|
||||
|
||||
@ -8,7 +8,7 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from dify_graph.file import File, FileTransferMethod, FileType
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
|
||||
@ -221,3 +221,67 @@ def test_tool_node_data_filters_missing_tool_parameter_values() -> None:
|
||||
)
|
||||
|
||||
assert set(node_data.tool_parameters.keys()) == {"query"}
|
||||
|
||||
|
||||
def test_generate_parameters_reads_variables_and_optional_missing_inputs(tool_node: ToolNode) -> None:
|
||||
variable_pool = MagicMock()
|
||||
variable_pool.get.side_effect = [MagicMock(value="from-variable"), None]
|
||||
node_data = ToolNodeData.model_validate(
|
||||
{
|
||||
"title": "Tool",
|
||||
"provider_id": "provider",
|
||||
"provider_type": "builtin",
|
||||
"provider_name": "provider",
|
||||
"tool_name": "tool",
|
||||
"tool_label": "tool",
|
||||
"tool_configurations": {},
|
||||
"tool_parameters": {
|
||||
"query": {"type": "variable", "value": ["start", "query"]},
|
||||
"optional": {"type": "variable", "value": ["start", "optional"]},
|
||||
},
|
||||
}
|
||||
)
|
||||
tool_parameters = [
|
||||
ToolParameter.get_simple_instance("query", "query", ToolParameter.ToolParameterType.STRING, True),
|
||||
ToolParameter.get_simple_instance("optional", "optional", ToolParameter.ToolParameterType.STRING, False),
|
||||
]
|
||||
|
||||
result = tool_node._generate_parameters(
|
||||
tool_parameters=tool_parameters,
|
||||
variable_pool=variable_pool,
|
||||
node_data=node_data,
|
||||
)
|
||||
|
||||
assert result == {"query": "from-variable"}
|
||||
|
||||
|
||||
def test_generate_parameters_formats_logs_and_unknown_parameters(tool_node: ToolNode) -> None:
|
||||
variable_pool = MagicMock()
|
||||
variable_pool.convert_template.return_value = MagicMock(text="rendered", log="masked")
|
||||
node_data = ToolNodeData.model_validate(
|
||||
{
|
||||
"title": "Tool",
|
||||
"provider_id": "provider",
|
||||
"provider_type": "builtin",
|
||||
"provider_name": "provider",
|
||||
"tool_name": "tool",
|
||||
"tool_label": "tool",
|
||||
"tool_configurations": {},
|
||||
"tool_parameters": {
|
||||
"query": {"type": "mixed", "value": "{{ question }}"},
|
||||
"missing": {"type": "constant", "value": "literal"},
|
||||
},
|
||||
}
|
||||
)
|
||||
tool_parameters = [
|
||||
ToolParameter.get_simple_instance("query", "query", ToolParameter.ToolParameterType.STRING, True),
|
||||
]
|
||||
|
||||
result = tool_node._generate_parameters(
|
||||
tool_parameters=tool_parameters,
|
||||
variable_pool=variable_pool,
|
||||
node_data=node_data,
|
||||
for_log=True,
|
||||
)
|
||||
|
||||
assert result == {"query": "masked", "missing": None}
|
||||
|
||||
@ -97,6 +97,22 @@ class TestWorkflowChildEngineBuilder:
|
||||
((sentinel.layer_two,), {}),
|
||||
]
|
||||
|
||||
def test_build_child_engine_tolerates_invalid_graph_shape_until_graph_init(self):
|
||||
builder = workflow_entry._WorkflowChildEngineBuilder()
|
||||
|
||||
with (
|
||||
patch.object(workflow_entry, "DifyNodeFactory", return_value=sentinel.factory),
|
||||
patch.object(workflow_entry.Graph, "init", side_effect=ValueError("invalid graph")),
|
||||
):
|
||||
with pytest.raises(ValueError, match="invalid graph"):
|
||||
builder.build_child_engine(
|
||||
workflow_id="workflow-id",
|
||||
graph_init_params=sentinel.graph_init_params,
|
||||
graph_runtime_state=sentinel.graph_runtime_state,
|
||||
graph_config={"nodes": "invalid"},
|
||||
root_node_id="root",
|
||||
)
|
||||
|
||||
|
||||
class TestWorkflowEntryInit:
|
||||
def test_rejects_call_depth_above_limit(self):
|
||||
|
||||
Reference in New Issue
Block a user