test: add tests for llm node
Signed-off-by: -LAN- <laipz8200@outlook.com>
@@ -20,16 +20,34 @@ from core.workflow.system_variables import default_system_variables
from dify_graph.entities import GraphInitParams
from dify_graph.file import File, FileTransferMethod, FileType
from dify_graph.model_runtime.entities.common_entities import I18nObject
from dify_graph.model_runtime.entities.llm_entities import (
    LLMResultChunk,
    LLMResultChunkDelta,
    LLMResultChunkWithStructuredOutput,
    LLMResultWithStructuredOutput,
    LLMUsage,
)
from dify_graph.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    ImagePromptMessageContent,
    PromptMessage,
    PromptMessageRole,
    SystemPromptMessage,
    TextPromptMessageContent,
    UserPromptMessage,
)
from dify_graph.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from dify_graph.model_runtime.entities.model_entities import (
    AIModelEntity,
    FetchFrom,
    ModelFeature,
    ModelPropertyKey,
    ModelType,
    ParameterRule,
    ParameterType,
)
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from dify_graph.node_events import ModelInvokeCompletedEvent, RunRetrieverResourceEvent, StreamChunkEvent
from dify_graph.nodes.base.entities import VariableSelector
from dify_graph.nodes.llm import llm_utils
from dify_graph.nodes.llm.entities import (
    ContextConfig,
@@ -37,14 +55,29 @@ from dify_graph.nodes.llm.entities import (
    LLMNodeCompletionModelPromptTemplate,
    LLMNodeData,
    ModelConfig,
    PromptConfig,
    VisionConfig,
    VisionConfigOptions,
)
from dify_graph.nodes.llm.exc import (
    InvalidContextStructureError,
    LLMNodeError,
    NoPromptFoundError,
    VariableNotFoundError,
)
from dify_graph.nodes.llm.file_saver import LLMFileSaver
from dify_graph.nodes.llm.node import LLMNode, _handle_completion_template, _handle_memory_completion_mode
from dify_graph.nodes.llm.node import (
    LLMNode,
    _calculate_rest_token,
    _handle_completion_template,
    _handle_memory_chat_mode,
    _handle_memory_completion_mode,
    _render_jinja2_message,
)
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
from dify_graph.nodes.llm.runtime_protocols import PromptMessageSerializerProtocol
from dify_graph.runtime import GraphRuntimeState, VariablePool
from dify_graph.template_rendering import TemplateRenderError
from dify_graph.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment
from models.provider import ProviderType
from tests.workflow_test_utils import build_test_graph_init_params
@@ -80,6 +113,44 @@ def _build_prepared_llm_mock() -> mock.MagicMock:
    return model_instance
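

# Test helper: a minimal AIModelEntity; features, model_properties, and
# parameter_rules default to empty so each test overrides only what it needs.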
def _build_model_schema(
    *,
    features: list[ModelFeature] | None = None,
    model_properties: dict[ModelPropertyKey, object] | None = None,
    parameter_rules: list[ParameterRule] | None = None,
) -> AIModelEntity:
    return AIModelEntity(
        model="gpt-3.5-turbo",
        label=I18nObject(en_US="GPT-3.5 Turbo"),
        model_type=ModelType.LLM,
        fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
        features=features,
        model_properties=model_properties or {},
        parameter_rules=parameter_rules or [],
    )


def _build_image_file(
    *,
    file_id: str,
    related_id: str,
    remote_url: str,
    extension: str = ".png",
    mime_type: str = "image/png",
) -> File:
    return File(
        id=file_id,
        type=FileType.IMAGE,
        filename=f"{file_id}{extension}",
        transfer_method=FileTransferMethod.REMOTE_URL,
        remote_url=remote_url,
        related_id=related_id,
        extension=extension,
        mime_type=mime_type,
        storage_key="",
    )


@pytest.fixture
def llm_node_data() -> LLMNodeData:
    return LLMNodeData(
@@ -679,6 +750,367 @@ def test_handle_list_messages_replaces_double_brace_context_placeholder(llm_node
    ]


def test_handle_list_messages_renders_jinja2_messages(llm_node):
    llm_node.graph_runtime_state.variable_pool.add(["input", "name"], "Dify")
    renderer = mock.MagicMock()
    renderer.render_template.return_value = "Hello Dify"

    prompt_messages = llm_node.handle_list_messages(
        messages=[
            LLMNodeChatModelMessage(
                text="ignored",
                jinja2_text="Hello {{ name }}",
                role=PromptMessageRole.SYSTEM,
                edition_type="jinja2",
            )
        ],
        context="",
        jinja2_variables=[VariableSelector(variable="name", value_selector=["input", "name"])],
        variable_pool=llm_node.graph_runtime_state.variable_pool,
        vision_detail_config=ImagePromptMessageContent.DETAIL.HIGH,
        jinja2_template_renderer=renderer,
    )

    assert prompt_messages == [
        SystemPromptMessage(content=[TextPromptMessageContent(data="Hello Dify")]),
    ]
    renderer.render_template.assert_called_once_with("Hello {{ name }}", {"name": "Dify"})


def test_transform_chat_messages_prefers_jinja2_text(llm_node):
    completion_template = LLMNodeCompletionModelPromptTemplate(
        text="ignored",
        jinja2_text="completion prompt",
        edition_type="jinja2",
    )
    chat_messages = [
        LLMNodeChatModelMessage(
            text="ignored",
            jinja2_text="chat prompt",
            role=PromptMessageRole.USER,
            edition_type="jinja2",
        ),
        LLMNodeChatModelMessage(
            text="keep original",
            role=PromptMessageRole.SYSTEM,
            edition_type="basic",
        ),
    ]

    transformed_completion = llm_node._transform_chat_messages(completion_template)
    transformed_messages = llm_node._transform_chat_messages(chat_messages)

    assert transformed_completion.text == "completion prompt"
    assert transformed_messages[0].text == "chat prompt"
    assert transformed_messages[1].text == "keep original"


def test_fetch_jinja_inputs_serializes_supported_segment_types(llm_node):
    llm_node.graph_runtime_state.variable_pool.add(
        ["input", "items"],
        ["alpha", {"metadata": {"_source": "knowledge"}, "content": "beta"}, 3],
    )
    llm_node.graph_runtime_state.variable_pool.add(
        ["input", "context_doc"],
        {"metadata": {"_source": "knowledge"}, "content": "context body"},
    )
    llm_node.graph_runtime_state.variable_pool.add(["input", "payload"], {"a": 1})

    node_data = llm_node.node_data.model_copy(
        update={
            "prompt_config": PromptConfig(
                jinja2_variables=[
                    VariableSelector(variable="items", value_selector=["input", "items"]),
                    VariableSelector(variable="context_doc", value_selector=["input", "context_doc"]),
                    VariableSelector(variable="payload", value_selector=["input", "payload"]),
                ]
            )
        }
    )
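
    # Expected serialization: list items join with newlines, knowledge chunks
    # (metadata._source == "knowledge") reduce to their "content", and plain
    # dicts are JSON-encoded.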
    assert llm_node._fetch_jinja_inputs(node_data) == {
        "items": "alpha\nbeta\n3",
        "context_doc": "context body",
        "payload": '{"a": 1}',
    }


def test_fetch_jinja_inputs_raises_for_missing_variable(llm_node):
    node_data = llm_node.node_data.model_copy(
        update={
            "prompt_config": PromptConfig(
                jinja2_variables=[VariableSelector(variable="missing", value_selector=["input", "missing"])]
            )
        }
    )

    with pytest.raises(VariableNotFoundError, match="Variable missing not found"):
        llm_node._fetch_jinja_inputs(node_data)


def test_fetch_inputs_collects_prompt_and_memory_variables(llm_node):
    llm_node.graph_runtime_state.variable_pool.add(["input", "name"], "Dify")
    llm_node.graph_runtime_state.variable_pool.add(["input", "payload"], {"active": True})

    node_data = llm_node.node_data.model_copy(
        update={
            "prompt_template": [
                LLMNodeChatModelMessage(
                    text="Hello {{#input.name#}} with {{#input.payload#}}",
                    role=PromptMessageRole.USER,
                    edition_type="basic",
                )
            ],
            "memory": MemoryConfig(
                role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"),
                window=MemoryConfig.WindowConfig(enabled=True, size=1),
                query_prompt_template="Repeat {{#input.name#}}",
            ),
        }
    )

    assert llm_node._fetch_inputs(node_data) == {
        "#input.name#": "Dify",
        "#input.payload#": {"active": True},
    }


def test_fetch_context_emits_string_context_event(llm_node):
    llm_node.graph_runtime_state.variable_pool.add(["context", "value"], "retrieved context")
    node_data = llm_node.node_data.model_copy(
        update={"context": ContextConfig(enabled=True, variable_selector=["context", "value"])}
    )

    events = list(llm_node._fetch_context(node_data))

    assert events == [
        RunRetrieverResourceEvent(retriever_resources=[], context="retrieved context", context_files=[]),
    ]


def test_fetch_context_collects_retriever_resources_and_attachments(llm_node):
    attachment = _build_image_file(
        file_id="attachment",
        related_id="attachment-related",
        remote_url="https://example.com/attachment.png",
    )
    llm_node._retriever_attachment_loader = mock.MagicMock()
    llm_node._retriever_attachment_loader.load.return_value = [attachment]

    llm_node.graph_runtime_state.variable_pool.add(
        ["context", "value"],
        [
            {
                "content": "chunk body",
                "summary": "chunk summary",
                "files": [{"id": "file-1"}],
                "metadata": {
                    "_source": "knowledge",
                    "dataset_id": "dataset-1",
                    "segment_id": "segment-1",
                    "segment_word_count": 12,
                },
            },
            "tail text",
        ],
    )
    node_data = llm_node.node_data.model_copy(
        update={"context": ContextConfig(enabled=True, variable_selector=["context", "value"])}
    )

    events = list(llm_node._fetch_context(node_data))

    assert len(events) == 1
    event = events[0]
    assert event.context == "chunk summary\nchunk body\ntail text"
    assert event.context_files == [attachment]
    assert event.retriever_resources == [
        {
            "position": None,
            "dataset_id": "dataset-1",
            "dataset_name": None,
            "document_id": None,
            "document_name": None,
            "data_source_type": None,
            "segment_id": "segment-1",
            "retriever_from": None,
            "score": None,
            "hit_count": None,
            "word_count": 12,
            "segment_position": None,
            "index_node_hash": None,
            "content": "chunk body",
            "page": None,
            "doc_metadata": None,
            "files": [{"id": "file-1"}],
            "summary": "chunk summary",
        }
    ]
    llm_node._retriever_attachment_loader.load.assert_called_once_with(segment_id="segment-1")


def test_fetch_context_rejects_invalid_context_structure(llm_node):
    llm_node.graph_runtime_state.variable_pool.add(["context", "value"], [{"summary": "missing content"}])
    node_data = llm_node.node_data.model_copy(
        update={"context": ContextConfig(enabled=True, variable_selector=["context", "value"])}
    )

    with pytest.raises(InvalidContextStructureError, match="Invalid context structure"):
        list(llm_node._fetch_context(node_data))


def test_fetch_prompt_messages_chat_mode_appends_memory_query_and_files():
    model_instance = _build_prepared_llm_mock()
    model_instance.get_model_schema.return_value = _build_model_schema(features=[ModelFeature.VISION])

    memory = mock.MagicMock(spec=MockTokenBufferMemory)
    memory.get_history_prompt_messages.return_value = [AssistantPromptMessage(content="history answer")]

    sys_file = _build_image_file(file_id="sys-file", related_id="sys-related", remote_url="https://example.com/sys.png")
    context_file = _build_image_file(
        file_id="context-file",
        related_id="context-related",
        remote_url="https://example.com/context.png",
    )

    prompt_content_side_effect = [
        ImagePromptMessageContent(
            url="https://example.com/sys.png",
            format="png",
            mime_type="image/png",
            detail=ImagePromptMessageContent.DETAIL.HIGH,
        ),
        ImagePromptMessageContent(
            url="https://example.com/context.png",
            format="png",
            mime_type="image/png",
            detail=ImagePromptMessageContent.DETAIL.HIGH,
        ),
    ]

    with mock.patch("dify_graph.nodes.llm.node.file_manager.to_prompt_message_content") as mock_to_prompt:
        mock_to_prompt.side_effect = prompt_content_side_effect
        prompt_messages, stop = LLMNode.fetch_prompt_messages(
            sys_query="current question",
            sys_files=[sys_file],
            context="",
            memory=memory,
            model_instance=model_instance,
            prompt_template=[
                LLMNodeChatModelMessage(
                    text="Before query",
                    role=PromptMessageRole.USER,
                    edition_type="basic",
                )
            ],
            stop=("STOP",),
            memory_config=MemoryConfig(
                window=MemoryConfig.WindowConfig(enabled=False),
            ),
            vision_enabled=True,
            vision_detail=ImagePromptMessageContent.DETAIL.HIGH,
            variable_pool=VariablePool.empty(),
            jinja2_variables=[],
            context_files=[context_file],
        )

    assert stop == ("STOP",)
    assert prompt_messages[0] == UserPromptMessage(content="Before query")
    assert prompt_messages[1] == AssistantPromptMessage(content="history answer")
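    # The final user message packs vision content first: context files, then
    # sys files, then the query text.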
    assert isinstance(prompt_messages[2], UserPromptMessage)
    assert isinstance(prompt_messages[2].content, list)
    assert isinstance(prompt_messages[2].content[0], ImagePromptMessageContent)
    assert isinstance(prompt_messages[2].content[1], ImagePromptMessageContent)
    assert isinstance(prompt_messages[2].content[2], TextPromptMessageContent)
    assert prompt_messages[2].content[0].url == "https://example.com/context.png"
    assert prompt_messages[2].content[1].url == "https://example.com/sys.png"
    assert prompt_messages[2].content[2].data == "current question"
    memory.get_history_prompt_messages.assert_called_once_with(max_token_limit=2000, message_limit=None)


def test_fetch_prompt_messages_completion_mode_injects_histories_and_query():
    model_instance = _build_prepared_llm_mock()
    model_instance.get_model_schema.return_value = _build_model_schema(features=[])

    memory = mock.MagicMock(spec=MockTokenBufferMemory)
    memory.get_history_prompt_messages.return_value = [
        UserPromptMessage(content="previous question"),
        AssistantPromptMessage(content="previous answer"),
    ]

    prompt_messages, stop = LLMNode.fetch_prompt_messages(
        sys_query="latest question",
        sys_files=[],
        context="",
        memory=memory,
        model_instance=model_instance,
        prompt_template=LLMNodeCompletionModelPromptTemplate(
            text="Prompt header\n#histories#",
            edition_type="basic",
        ),
        stop=("HALT",),
        memory_config=MemoryConfig(
            role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"),
            window=MemoryConfig.WindowConfig(enabled=True, size=2),
        ),
        vision_enabled=False,
        vision_detail=ImagePromptMessageContent.DETAIL.HIGH,
        variable_pool=VariablePool.empty(),
        jinja2_variables=[],
    )

    assert stop == ("HALT",)
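    # Completion mode prepends the query and expands the "#histories#"
    # placeholder with role-prefixed memory turns.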
    assert prompt_messages == [
        UserPromptMessage(
            content="latest question\nPrompt header\nHuman: previous question\nAssistant: previous answer"
        )
    ]


def test_fetch_prompt_messages_raises_when_only_unsupported_content_remains():
    model_instance = _build_prepared_llm_mock()
    model_instance.get_model_schema.return_value = _build_model_schema(features=[])

    variable_pool = VariablePool.empty()
    variable_pool.add(
        ["input", "image"],
        _build_image_file(file_id="image-file", related_id="image-related", remote_url="https://example.com/file.png"),
    )

    with (
        mock.patch(
            "dify_graph.nodes.llm.node.file_manager.to_prompt_message_content",
            return_value=ImagePromptMessageContent(
                url="https://example.com/file.png",
                format="png",
                mime_type="image/png",
                detail=ImagePromptMessageContent.DETAIL.HIGH,
            ),
        ),
        pytest.raises(NoPromptFoundError, match="No prompt found"),
    ):
        LLMNode.fetch_prompt_messages(
            sys_query=None,
            sys_files=[],
            context="",
            memory=None,
            model_instance=model_instance,
            prompt_template=[
                LLMNodeChatModelMessage(
                    text="{{#input.image#}}",
                    role=PromptMessageRole.USER,
                    edition_type="basic",
                )
            ],
            stop=None,
            memory_config=None,
            vision_enabled=False,
            vision_detail=ImagePromptMessageContent.DETAIL.HIGH,
            variable_pool=variable_pool,
            jinja2_variables=[],
        )


def test_handle_completion_template_replaces_double_brace_context_placeholder(llm_node):
    prompt_messages = _handle_completion_template(
        template=LLMNodeCompletionModelPromptTemplate(
@@ -986,6 +1418,286 @@ class TestReasoningFormat:
        assert reasoning_content == ""


@pytest.mark.parametrize(
    ("structured_output_enabled", "structured_output"),
    [
        (False, None),
        (True, {"schema": {"type": "object", "properties": {"answer": {"type": "string"}}}}),
    ],
)
def test_invoke_llm_dispatches_to_expected_model_method(structured_output_enabled, structured_output):
    model_instance = _build_prepared_llm_mock()
    prompt_messages = [UserPromptMessage(content="hello")]
    file_saver = mock.MagicMock(spec=LLMFileSaver)

    model_instance.invoke_llm.return_value = iter([])
    model_instance.invoke_llm_with_structured_output.return_value = iter([])

    with (
        mock.patch.object(LLMNode, "handle_invoke_result", return_value=iter(["handled"])) as mock_handle,
        mock.patch("dify_graph.nodes.llm.node.time.perf_counter", return_value=10.0),
    ):
        result = list(
            LLMNode.invoke_llm(
                model_instance=model_instance,
                prompt_messages=prompt_messages,
                stop=("STOP",),
                structured_output_enabled=structured_output_enabled,
                structured_output=structured_output,
                file_saver=file_saver,
                file_outputs=[],
                node_id="node-1",
                node_type=LLMNode.node_type,
                reasoning_format="separated",
            )
        )

    assert result == ["handled"]
    if structured_output_enabled:
        model_instance.invoke_llm_with_structured_output.assert_called_once_with(
            prompt_messages=prompt_messages,
            json_schema={"type": "object", "properties": {"answer": {"type": "string"}}},
            model_parameters={},
            stop=("STOP",),
            stream=True,
        )
        model_instance.invoke_llm.assert_not_called()
    else:
        model_instance.invoke_llm.assert_called_once_with(
            prompt_messages=prompt_messages,
            model_parameters={},
            tools=None,
            stop=("STOP",),
            stream=True,
        )
        model_instance.invoke_llm_with_structured_output.assert_not_called()

    assert mock_handle.call_args.kwargs["request_start_time"] == 10.0


def test_handle_invoke_result_streaming_collects_text_metrics_and_structured_output():
    usage = LLMUsage.from_metadata({"prompt_tokens": 12, "completion_tokens": 4, "total_tokens": 16})
    first_chunk = LLMResultChunkWithStructuredOutput(
        model="gpt-3.5-turbo",
        prompt_messages=[],
        delta=LLMResultChunkDelta(
            index=0,
            message=AssistantPromptMessage(content=[TextPromptMessageContent(data="<think>plan</think>")]),
        ),
        structured_output={"draft": True},
    )
    final_chunk = LLMResultChunk(
        model="gpt-3.5-turbo",
        prompt_messages=[],
        delta=LLMResultChunkDelta(
            index=1,
            message=AssistantPromptMessage(content=[TextPromptMessageContent(data="answer")]),
            usage=usage,
            finish_reason="stop",
        ),
    )

    with mock.patch("dify_graph.nodes.llm.node.time.perf_counter", side_effect=[2.0, 5.0]):
        events = list(
            LLMNode.handle_invoke_result(
                invoke_result=iter([first_chunk, final_chunk]),
                file_saver=mock.MagicMock(spec=LLMFileSaver),
                file_outputs=[],
                node_id="node-1",
                node_type=LLMNode.node_type,
                model_instance=_build_prepared_llm_mock(),
                reasoning_format="separated",
                request_start_time=1.0,
            )
        )
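
    # reasoning_format="separated" strips the <think> block out of the
    # streamed text and surfaces it as reasoning_content.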
    assert events[0] == first_chunk
    assert events[1] == StreamChunkEvent(selector=["node-1", "text"], chunk="<think>plan</think>", is_final=False)
    assert events[2] == StreamChunkEvent(selector=["node-1", "text"], chunk="answer", is_final=False)

    completed = events[3]
    assert isinstance(completed, ModelInvokeCompletedEvent)
    assert completed.text == "answer"
    assert completed.reasoning_content == "plan"
    assert completed.structured_output == {"draft": True}
    assert completed.finish_reason == "stop"
    assert completed.usage.total_tokens == 16
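    # perf_counter fires at 2.0 (first chunk) and 5.0 (last chunk) against
    # request_start_time=1.0, so latency = 4.0, time to first token = 1.0,
    # and time to generate = 3.0.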
    assert completed.usage.latency == 4.0
    assert completed.usage.time_to_first_token == 1.0
    assert completed.usage.time_to_generate == 3.0


def test_handle_invoke_result_wraps_structured_output_parse_errors():
    model_instance = _build_prepared_llm_mock()
    model_instance.is_structured_output_parse_error.return_value = True

    def broken_stream():
        raise ValueError("bad json")
        yield  # unreachable yield keeps this a generator, so the error raises on iteration

    with pytest.raises(LLMNodeError, match="Failed to parse structured output: bad json"):
        list(
            LLMNode.handle_invoke_result(
                invoke_result=broken_stream(),
                file_saver=mock.MagicMock(spec=LLMFileSaver),
                file_outputs=[],
                node_id="node-1",
                node_type=LLMNode.node_type,
                model_instance=model_instance,
            )
        )


def test_handle_blocking_result_extracts_reasoning_and_structured_output():
    invoke_result = LLMResultWithStructuredOutput(
        model="gpt-3.5-turbo",
        prompt_messages=[],
        message=AssistantPromptMessage(content="<think>reasoning</think>final answer"),
        usage=LLMUsage.empty_usage(),
        structured_output={"answer": "final answer"},
    )

    event = LLMNode.handle_blocking_result(
        invoke_result=invoke_result,
        saver=mock.MagicMock(spec=LLMFileSaver),
        file_outputs=[],
        reasoning_format="separated",
        request_latency=1.2345,
    )

    assert event.text == "final answer"
    assert event.reasoning_content == "reasoning"
    assert event.structured_output == {"answer": "final answer"}
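    # request_latency is rounded to three decimal places (1.2345 -> 1.234).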
    assert event.usage.latency == 1.234


def test_fetch_structured_output_schema_validates_payload():
    assert LLMNode.fetch_structured_output_schema(structured_output={"schema": {"type": "object"}}) == {
        "type": "object"
    }

    with pytest.raises(LLMNodeError, match="Please provide a valid structured output schema"):
        LLMNode.fetch_structured_output_schema(structured_output={})

    with pytest.raises(LLMNodeError, match="structured_output_schema must be a JSON object"):
        LLMNode.fetch_structured_output_schema(structured_output={"schema": ["not", "an", "object"]})


def test_extract_variable_selector_to_variable_mapping_includes_runtime_selectors():
    node_data = LLMNodeData(
        title="Test LLM",
        model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
        prompt_template=[
            LLMNodeChatModelMessage(
                text="Hello {{#input.name#}}",
                role=PromptMessageRole.USER,
                edition_type="basic",
            ),
            LLMNodeChatModelMessage(
                text="ignored",
                jinja2_text="Hello {{ name }}",
                role=PromptMessageRole.SYSTEM,
                edition_type="jinja2",
            ),
        ],
        prompt_config=PromptConfig(
            jinja2_variables=[VariableSelector(variable="name", value_selector=["input", "name"])]
        ),
        memory=MemoryConfig(
            window=MemoryConfig.WindowConfig(enabled=True, size=1),
            query_prompt_template="Repeat {{#sys.query#}}",
        ),
        context=ContextConfig(enabled=True, variable_selector=["context", "value"]),
        vision=VisionConfig(enabled=True),
    )

    mapping = LLMNode._extract_variable_selector_to_variable_mapping(
        graph_config={},
        node_id="llm-1",
        node_data=node_data,
    )

    assert mapping == {
        "llm-1.#input.name#": ["input", "name"],
        "llm-1.#sys.query#": ["sys", "query"],
        "llm-1.#context#": ["context", "value"],
        "llm-1.#files#": ["sys", "files"],
        "llm-1.name": ["input", "name"],
    }


def test_render_jinja2_message_requires_renderer_and_passes_inputs():
    variable_pool = VariablePool.empty()
    variable_pool.add(["input", "name"], "Dify")
    variables = [VariableSelector(variable="name", value_selector=["input", "name"])]

    with pytest.raises(
        TemplateRenderError,
        match="LLMNode requires an injected jinja2_template_renderer for jinja2 prompts",
    ):
        _render_jinja2_message(
            template="Hello {{ name }}",
            jinja2_variables=variables,
            variable_pool=variable_pool,
            jinja2_template_renderer=None,
        )

    renderer = mock.MagicMock()
    renderer.render_template.return_value = "Hello Dify"

    assert (
        _render_jinja2_message(
            template="Hello {{ name }}",
            jinja2_variables=variables,
            variable_pool=variable_pool,
            jinja2_template_renderer=renderer,
        )
        == "Hello Dify"
    )
    renderer.render_template.assert_called_once_with("Hello {{ name }}", {"name": "Dify"})


def test_calculate_rest_token_uses_context_size_and_max_tokens():
    model_instance = _build_prepared_llm_mock()
    model_instance.parameters = {"max_tokens": 512}
    model_instance.get_model_schema.return_value = _build_model_schema(
        model_properties={ModelPropertyKey.CONTEXT_SIZE: 4096},
        parameter_rules=[
            ParameterRule(
                name="max_tokens",
                label=I18nObject(en_US="Max Tokens"),
                type=ParameterType.INT,
            )
        ],
    )
    model_instance.get_llm_num_tokens.return_value = 1000
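
    # Rest budget = context size (4096) - prompt tokens (1000) - max_tokens (512) = 2584.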
    assert (
        _calculate_rest_token(
            prompt_messages=[UserPromptMessage(content="hello")],
            model_instance=model_instance,
        )
        == 2584
    )


def test_handle_memory_chat_mode_uses_calculated_token_budget():
    memory = mock.MagicMock(spec=MockTokenBufferMemory)
    history = [UserPromptMessage(content="question")]
    memory.get_history_prompt_messages.return_value = history

    with mock.patch("dify_graph.nodes.llm.node._calculate_rest_token", return_value=321) as mock_rest_token:
        result = _handle_memory_chat_mode(
            memory=memory,
            memory_config=MemoryConfig(window=MemoryConfig.WindowConfig(enabled=True, size=2)),
            model_instance=_build_prepared_llm_mock(),
        )

    assert result == history
    mock_rest_token.assert_called_once()
    memory.get_history_prompt_messages.assert_called_once_with(max_token_limit=321, message_limit=2)


def test_dify_model_access_adapters_skip_runtime_build_when_managers_are_injected():
    run_context = DifyRunContext(
        tenant_id="tenant",