refactor(llm): decouple code executor module (#33400)

Co-authored-by: Byron.wang <byron@dify.ai>
This commit is contained in:
wangxiaolei
2026-03-16 10:06:14 +08:00
committed by GitHub
parent a6163f80d1
commit 6ef69ff880
11 changed files with 603 additions and 414 deletions

View File

@ -20,7 +20,7 @@ from dify_graph.nodes.code import CodeNode
from dify_graph.nodes.document_extractor import DocumentExtractorNode
from dify_graph.nodes.http_request import HttpRequestNode
from dify_graph.nodes.llm import LLMNode
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer
from dify_graph.nodes.parameter_extractor import ParameterExtractorNode
from dify_graph.nodes.protocols import HttpClientProtocol, ToolFileManagerProtocol
from dify_graph.nodes.question_classifier import QuestionClassifierNode
@ -68,6 +68,8 @@ class MockNodeMixin:
kwargs.setdefault("model_instance", MagicMock(spec=ModelInstance))
# LLM-like nodes now require an http_client; provide a mock by default for tests.
kwargs.setdefault("http_client", MagicMock(spec=HttpClientProtocol))
if isinstance(self, (LLMNode, QuestionClassifierNode)):
kwargs.setdefault("template_renderer", MagicMock(spec=TemplateRenderer))
# Ensure TemplateTransformNode receives a renderer now required by constructor
if isinstance(self, TemplateTransformNode):

View File

@ -34,8 +34,8 @@ from dify_graph.nodes.llm.entities import (
VisionConfigOptions,
)
from dify_graph.nodes.llm.file_saver import LLMFileSaver
from dify_graph.nodes.llm.node import LLMNode, _handle_memory_completion_mode
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
from dify_graph.nodes.llm.node import LLMNode
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer
from dify_graph.runtime import GraphRuntimeState, VariablePool
from dify_graph.system_variable import SystemVariable
from dify_graph.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment
@ -107,6 +107,7 @@ def llm_node(
mock_file_saver = mock.MagicMock(spec=LLMFileSaver)
mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider)
mock_model_factory = mock.MagicMock(spec=ModelFactory)
mock_template_renderer = mock.MagicMock(spec=TemplateRenderer)
node_config = {
"id": "1",
"data": llm_node_data.model_dump(),
@ -121,6 +122,7 @@ def llm_node(
model_factory=mock_model_factory,
model_instance=mock.MagicMock(spec=ModelInstance),
llm_file_saver=mock_file_saver,
template_renderer=mock_template_renderer,
http_client=http_client,
)
return node
@ -590,6 +592,33 @@ def test_handle_list_messages_basic(llm_node):
assert result[0].content == [TextPromptMessageContent(data="Hello, world")]
def test_handle_list_messages_jinja2_uses_template_renderer(llm_node):
    """A jinja2-edition message must be rendered via the node's injected TemplateRenderer.

    Verifies both the rendered prompt content and that the renderer received the
    raw template plus the (empty) jinja2 inputs exactly once.
    """
    renderer = llm_node._template_renderer
    renderer.render_jinja2.return_value = "Hello, world"

    # A single user message whose content comes from the jinja2 template, not `text`.
    jinja_message = LLMNodeChatModelMessage(
        text="",
        jinja2_text="Hello, {{ name }}",
        role=PromptMessageRole.USER,
        edition_type="jinja2",
    )

    rendered = llm_node.handle_list_messages(
        messages=[jinja_message],
        context=None,
        jinja2_variables=[],
        variable_pool=llm_node.graph_runtime_state.variable_pool,
        vision_detail_config=ImagePromptMessageContent.DETAIL.HIGH,
        template_renderer=renderer,
    )

    expected = [UserPromptMessage(content=[TextPromptMessageContent(data="Hello, world")])]
    assert rendered == expected
    # The renderer is the single rendering path: raw template in, no inputs resolved here.
    renderer.render_jinja2.assert_called_once_with(
        template="Hello, {{ name }}",
        inputs={},
    )
def test_handle_memory_completion_mode_uses_prompt_message_interface():
memory = mock.MagicMock(spec=MockTokenBufferMemory)
memory.get_history_prompt_messages.return_value = [
@ -613,8 +642,8 @@ def test_handle_memory_completion_mode_uses_prompt_message_interface():
window=MemoryConfig.WindowConfig(enabled=True, size=3),
)
with mock.patch("dify_graph.nodes.llm.node._calculate_rest_token", return_value=2000) as mock_rest_token:
memory_text = _handle_memory_completion_mode(
with mock.patch("dify_graph.nodes.llm.llm_utils.calculate_rest_token", return_value=2000) as mock_rest_token:
memory_text = llm_utils.handle_memory_completion_mode(
memory=memory,
memory_config=memory_config,
model_instance=model_instance,
@ -630,6 +659,7 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat
mock_file_saver: LLMFileSaver = mock.MagicMock(spec=LLMFileSaver)
mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider)
mock_model_factory = mock.MagicMock(spec=ModelFactory)
mock_template_renderer = mock.MagicMock(spec=TemplateRenderer)
node_config = {
"id": "1",
"data": llm_node_data.model_dump(),
@ -644,6 +674,7 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat
model_factory=mock_model_factory,
model_instance=mock.MagicMock(spec=ModelInstance),
llm_file_saver=mock_file_saver,
template_renderer=mock_template_renderer,
http_client=http_client,
)
return node, mock_file_saver

View File

@ -1,5 +1,14 @@
from types import SimpleNamespace
from unittest.mock import MagicMock
from dify_graph.model_runtime.entities import ImagePromptMessageContent
from dify_graph.nodes.question_classifier import QuestionClassifierNodeData
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer
from dify_graph.nodes.protocols import HttpClientProtocol
from dify_graph.nodes.question_classifier import (
QuestionClassifierNode,
QuestionClassifierNodeData,
)
from tests.workflow_test_utils import build_test_graph_init_params
def test_init_question_classifier_node_data():
@ -65,3 +74,52 @@ def test_init_question_classifier_node_data_without_vision_config():
assert node_data.vision.enabled == False
assert node_data.vision.configs.variable_selector == ["sys", "files"]
assert node_data.vision.configs.detail == ImagePromptMessageContent.DETAIL.HIGH
def test_question_classifier_calculate_rest_token_uses_shared_prompt_builder(monkeypatch):
    """_calculate_rest_token must delegate prompt construction to the shared
    llm_utils.fetch_prompt_messages helper, forwarding the node's own
    TemplateRenderer instance.
    """
    raw_node_data = {
        "title": "test classifier node",
        "query_variable_selector": ["id", "name"],
        "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "completion", "completion_params": {}},
        "classes": [{"id": "1", "name": "class 1"}],
        "instruction": "This is a test instruction",
    }
    node_data = QuestionClassifierNodeData.model_validate(raw_node_data)

    renderer_mock = MagicMock(spec=TemplateRenderer)
    init_params = build_test_graph_init_params(
        workflow_id="workflow-id",
        graph_config={},
        tenant_id="tenant-id",
        app_id="app-id",
        user_id="user-id",
    )
    classifier_node = QuestionClassifierNode(
        id="node-id",
        config={"id": "node-id", "data": node_data.model_dump(mode="json")},
        graph_init_params=init_params,
        graph_runtime_state=SimpleNamespace(variable_pool=MagicMock()),
        credentials_provider=MagicMock(spec=CredentialsProvider),
        model_factory=MagicMock(spec=ModelFactory),
        model_instance=MagicMock(),
        http_client=MagicMock(spec=HttpClientProtocol),
        llm_file_saver=MagicMock(),
        template_renderer=renderer_mock,
    )

    # Stub out the shared llm_utils helpers so the test only observes the delegation.
    fetch_prompt_messages_mock = MagicMock(return_value=([], None))
    monkeypatch.setattr(
        "dify_graph.nodes.question_classifier.question_classifier_node.llm_utils.fetch_prompt_messages",
        fetch_prompt_messages_mock,
    )
    monkeypatch.setattr(
        "dify_graph.nodes.question_classifier.question_classifier_node.llm_utils.fetch_model_schema",
        MagicMock(return_value=SimpleNamespace(model_properties={}, parameter_rules=[])),
    )

    classifier_node._calculate_rest_token(
        node_data=node_data,
        query="hello",
        model_instance=MagicMock(stop=(), parameters={}),
        context="",
    )

    # The node's renderer — not a freshly built one — must reach the prompt builder.
    assert fetch_prompt_messages_mock.call_args.kwargs["template_renderer"] is renderer_mock

View File

@ -140,6 +140,29 @@ class TestDefaultWorkflowCodeExecutor:
assert executor.is_execution_error(RuntimeError("boom")) is False
class TestDefaultLLMTemplateRenderer:
    """Tests for the default TemplateRenderer implementation backed by CodeExecutor."""

    def test_render_jinja2_delegates_to_code_executor(self, monkeypatch):
        """render_jinja2 must forward the template and inputs to
        CodeExecutor.execute_workflow_code_template with the JINJA2 language
        and return the executor's "result" field.
        """
        executor_stub = MagicMock(return_value={"result": "hello world"})
        monkeypatch.setattr(
            node_factory.CodeExecutor,
            "execute_workflow_code_template",
            executor_stub,
        )

        renderer = node_factory.DefaultLLMTemplateRenderer()
        rendered = renderer.render_jinja2(template="Hello {{ name }}", inputs={"name": "world"})

        assert rendered == "hello world"
        # Exactly one delegation, with the template passed through as the code body.
        executor_stub.assert_called_once_with(
            language=CodeLanguage.JINJA2,
            code="Hello {{ name }}",
            inputs={"name": "world"},
        )
class TestDifyNodeFactoryInit:
def test_init_builds_default_dependencies(self):
graph_init_params = SimpleNamespace(run_context={"context": "value"})
@ -150,6 +173,7 @@ class TestDifyNodeFactoryInit:
http_request_config = sentinel.http_request_config
credentials_provider = sentinel.credentials_provider
model_factory = sentinel.model_factory
llm_template_renderer = sentinel.llm_template_renderer
with (
patch.object(
@ -172,6 +196,11 @@ class TestDifyNodeFactoryInit:
"build_http_request_config",
return_value=http_request_config,
),
patch.object(
node_factory,
"DefaultLLMTemplateRenderer",
return_value=llm_template_renderer,
) as llm_renderer_factory,
patch.object(
node_factory,
"build_dify_model_access",
@ -186,11 +215,14 @@ class TestDifyNodeFactoryInit:
resolve_dify_context.assert_called_once_with(graph_init_params.run_context)
build_dify_model_access.assert_called_once_with("tenant-id")
renderer_factory.assert_called_once()
llm_renderer_factory.assert_called_once()
assert renderer_factory.call_args.kwargs["code_executor"] is factory._code_executor
assert factory.graph_init_params is graph_init_params
assert factory.graph_runtime_state is graph_runtime_state
assert factory._dify_context is dify_context
assert factory._template_renderer is template_renderer
assert factory._llm_template_renderer is llm_template_renderer
assert factory._document_extractor_unstructured_api_config is unstructured_api_config
assert factory._http_request_config is http_request_config
assert factory._llm_credentials_provider is credentials_provider
@ -242,6 +274,7 @@ class TestDifyNodeFactoryCreateNode:
factory._code_executor = sentinel.code_executor
factory._code_limits = sentinel.code_limits
factory._template_renderer = sentinel.template_renderer
factory._llm_template_renderer = sentinel.llm_template_renderer
factory._template_transform_max_output_length = 2048
factory._http_request_http_client = sentinel.http_client
factory._http_request_tool_file_manager_factory = sentinel.tool_file_manager_factory
@ -378,8 +411,22 @@ class TestDifyNodeFactoryCreateNode:
@pytest.mark.parametrize(
("node_type", "constructor_name", "expected_extra_kwargs"),
[
(BuiltinNodeTypes.LLM, "LLMNode", {"http_client": sentinel.http_client}),
(BuiltinNodeTypes.QUESTION_CLASSIFIER, "QuestionClassifierNode", {"http_client": sentinel.http_client}),
(
BuiltinNodeTypes.LLM,
"LLMNode",
{
"http_client": sentinel.http_client,
"template_renderer": sentinel.llm_template_renderer,
},
),
(
BuiltinNodeTypes.QUESTION_CLASSIFIER,
"QuestionClassifierNode",
{
"http_client": sentinel.http_client,
"template_renderer": sentinel.llm_template_renderer,
},
),
(BuiltinNodeTypes.PARAMETER_EXTRACTOR, "ParameterExtractorNode", {}),
],
)