mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 17:08:03 +08:00
refactor: llm decouple code executor module (#33400)
Co-authored-by: Byron.wang <byron@dify.ai>
This commit is contained in:
@ -20,7 +20,7 @@ from dify_graph.nodes.code import CodeNode
|
||||
from dify_graph.nodes.document_extractor import DocumentExtractorNode
|
||||
from dify_graph.nodes.http_request import HttpRequestNode
|
||||
from dify_graph.nodes.llm import LLMNode
|
||||
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer
|
||||
from dify_graph.nodes.parameter_extractor import ParameterExtractorNode
|
||||
from dify_graph.nodes.protocols import HttpClientProtocol, ToolFileManagerProtocol
|
||||
from dify_graph.nodes.question_classifier import QuestionClassifierNode
|
||||
@ -68,6 +68,8 @@ class MockNodeMixin:
|
||||
kwargs.setdefault("model_instance", MagicMock(spec=ModelInstance))
|
||||
# LLM-like nodes now require an http_client; provide a mock by default for tests.
|
||||
kwargs.setdefault("http_client", MagicMock(spec=HttpClientProtocol))
|
||||
if isinstance(self, (LLMNode, QuestionClassifierNode)):
|
||||
kwargs.setdefault("template_renderer", MagicMock(spec=TemplateRenderer))
|
||||
|
||||
# Ensure TemplateTransformNode receives a renderer now required by constructor
|
||||
if isinstance(self, TemplateTransformNode):
|
||||
|
||||
@ -34,8 +34,8 @@ from dify_graph.nodes.llm.entities import (
|
||||
VisionConfigOptions,
|
||||
)
|
||||
from dify_graph.nodes.llm.file_saver import LLMFileSaver
|
||||
from dify_graph.nodes.llm.node import LLMNode, _handle_memory_completion_mode
|
||||
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||
from dify_graph.nodes.llm.node import LLMNode
|
||||
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer
|
||||
from dify_graph.runtime import GraphRuntimeState, VariablePool
|
||||
from dify_graph.system_variable import SystemVariable
|
||||
from dify_graph.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment
|
||||
@ -107,6 +107,7 @@ def llm_node(
|
||||
mock_file_saver = mock.MagicMock(spec=LLMFileSaver)
|
||||
mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider)
|
||||
mock_model_factory = mock.MagicMock(spec=ModelFactory)
|
||||
mock_template_renderer = mock.MagicMock(spec=TemplateRenderer)
|
||||
node_config = {
|
||||
"id": "1",
|
||||
"data": llm_node_data.model_dump(),
|
||||
@ -121,6 +122,7 @@ def llm_node(
|
||||
model_factory=mock_model_factory,
|
||||
model_instance=mock.MagicMock(spec=ModelInstance),
|
||||
llm_file_saver=mock_file_saver,
|
||||
template_renderer=mock_template_renderer,
|
||||
http_client=http_client,
|
||||
)
|
||||
return node
|
||||
@ -590,6 +592,33 @@ def test_handle_list_messages_basic(llm_node):
|
||||
assert result[0].content == [TextPromptMessageContent(data="Hello, world")]
|
||||
|
||||
|
||||
def test_handle_list_messages_jinja2_uses_template_renderer(llm_node):
|
||||
llm_node._template_renderer.render_jinja2.return_value = "Hello, world"
|
||||
messages = [
|
||||
LLMNodeChatModelMessage(
|
||||
text="",
|
||||
jinja2_text="Hello, {{ name }}",
|
||||
role=PromptMessageRole.USER,
|
||||
edition_type="jinja2",
|
||||
)
|
||||
]
|
||||
|
||||
result = llm_node.handle_list_messages(
|
||||
messages=messages,
|
||||
context=None,
|
||||
jinja2_variables=[],
|
||||
variable_pool=llm_node.graph_runtime_state.variable_pool,
|
||||
vision_detail_config=ImagePromptMessageContent.DETAIL.HIGH,
|
||||
template_renderer=llm_node._template_renderer,
|
||||
)
|
||||
|
||||
assert result == [UserPromptMessage(content=[TextPromptMessageContent(data="Hello, world")])]
|
||||
llm_node._template_renderer.render_jinja2.assert_called_once_with(
|
||||
template="Hello, {{ name }}",
|
||||
inputs={},
|
||||
)
|
||||
|
||||
|
||||
def test_handle_memory_completion_mode_uses_prompt_message_interface():
|
||||
memory = mock.MagicMock(spec=MockTokenBufferMemory)
|
||||
memory.get_history_prompt_messages.return_value = [
|
||||
@ -613,8 +642,8 @@ def test_handle_memory_completion_mode_uses_prompt_message_interface():
|
||||
window=MemoryConfig.WindowConfig(enabled=True, size=3),
|
||||
)
|
||||
|
||||
with mock.patch("dify_graph.nodes.llm.node._calculate_rest_token", return_value=2000) as mock_rest_token:
|
||||
memory_text = _handle_memory_completion_mode(
|
||||
with mock.patch("dify_graph.nodes.llm.llm_utils.calculate_rest_token", return_value=2000) as mock_rest_token:
|
||||
memory_text = llm_utils.handle_memory_completion_mode(
|
||||
memory=memory,
|
||||
memory_config=memory_config,
|
||||
model_instance=model_instance,
|
||||
@ -630,6 +659,7 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat
|
||||
mock_file_saver: LLMFileSaver = mock.MagicMock(spec=LLMFileSaver)
|
||||
mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider)
|
||||
mock_model_factory = mock.MagicMock(spec=ModelFactory)
|
||||
mock_template_renderer = mock.MagicMock(spec=TemplateRenderer)
|
||||
node_config = {
|
||||
"id": "1",
|
||||
"data": llm_node_data.model_dump(),
|
||||
@ -644,6 +674,7 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat
|
||||
model_factory=mock_model_factory,
|
||||
model_instance=mock.MagicMock(spec=ModelInstance),
|
||||
llm_file_saver=mock_file_saver,
|
||||
template_renderer=mock_template_renderer,
|
||||
http_client=http_client,
|
||||
)
|
||||
return node, mock_file_saver
|
||||
|
||||
@ -1,5 +1,14 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from dify_graph.model_runtime.entities import ImagePromptMessageContent
|
||||
from dify_graph.nodes.question_classifier import QuestionClassifierNodeData
|
||||
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer
|
||||
from dify_graph.nodes.protocols import HttpClientProtocol
|
||||
from dify_graph.nodes.question_classifier import (
|
||||
QuestionClassifierNode,
|
||||
QuestionClassifierNodeData,
|
||||
)
|
||||
from tests.workflow_test_utils import build_test_graph_init_params
|
||||
|
||||
|
||||
def test_init_question_classifier_node_data():
|
||||
@ -65,3 +74,52 @@ def test_init_question_classifier_node_data_without_vision_config():
|
||||
assert node_data.vision.enabled == False
|
||||
assert node_data.vision.configs.variable_selector == ["sys", "files"]
|
||||
assert node_data.vision.configs.detail == ImagePromptMessageContent.DETAIL.HIGH
|
||||
|
||||
|
||||
def test_question_classifier_calculate_rest_token_uses_shared_prompt_builder(monkeypatch):
|
||||
node_data = QuestionClassifierNodeData.model_validate(
|
||||
{
|
||||
"title": "test classifier node",
|
||||
"query_variable_selector": ["id", "name"],
|
||||
"model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "completion", "completion_params": {}},
|
||||
"classes": [{"id": "1", "name": "class 1"}],
|
||||
"instruction": "This is a test instruction",
|
||||
}
|
||||
)
|
||||
template_renderer = MagicMock(spec=TemplateRenderer)
|
||||
node = QuestionClassifierNode(
|
||||
id="node-id",
|
||||
config={"id": "node-id", "data": node_data.model_dump(mode="json")},
|
||||
graph_init_params=build_test_graph_init_params(
|
||||
workflow_id="workflow-id",
|
||||
graph_config={},
|
||||
tenant_id="tenant-id",
|
||||
app_id="app-id",
|
||||
user_id="user-id",
|
||||
),
|
||||
graph_runtime_state=SimpleNamespace(variable_pool=MagicMock()),
|
||||
credentials_provider=MagicMock(spec=CredentialsProvider),
|
||||
model_factory=MagicMock(spec=ModelFactory),
|
||||
model_instance=MagicMock(),
|
||||
http_client=MagicMock(spec=HttpClientProtocol),
|
||||
llm_file_saver=MagicMock(),
|
||||
template_renderer=template_renderer,
|
||||
)
|
||||
fetch_prompt_messages = MagicMock(return_value=([], None))
|
||||
monkeypatch.setattr(
|
||||
"dify_graph.nodes.question_classifier.question_classifier_node.llm_utils.fetch_prompt_messages",
|
||||
fetch_prompt_messages,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"dify_graph.nodes.question_classifier.question_classifier_node.llm_utils.fetch_model_schema",
|
||||
MagicMock(return_value=SimpleNamespace(model_properties={}, parameter_rules=[])),
|
||||
)
|
||||
|
||||
node._calculate_rest_token(
|
||||
node_data=node_data,
|
||||
query="hello",
|
||||
model_instance=MagicMock(stop=(), parameters={}),
|
||||
context="",
|
||||
)
|
||||
|
||||
assert fetch_prompt_messages.call_args.kwargs["template_renderer"] is template_renderer
|
||||
|
||||
@ -140,6 +140,29 @@ class TestDefaultWorkflowCodeExecutor:
|
||||
assert executor.is_execution_error(RuntimeError("boom")) is False
|
||||
|
||||
|
||||
class TestDefaultLLMTemplateRenderer:
|
||||
def test_render_jinja2_delegates_to_code_executor(self, monkeypatch):
|
||||
renderer = node_factory.DefaultLLMTemplateRenderer()
|
||||
execute_workflow_code_template = MagicMock(return_value={"result": "hello world"})
|
||||
monkeypatch.setattr(
|
||||
node_factory.CodeExecutor,
|
||||
"execute_workflow_code_template",
|
||||
execute_workflow_code_template,
|
||||
)
|
||||
|
||||
result = renderer.render_jinja2(
|
||||
template="Hello {{ name }}",
|
||||
inputs={"name": "world"},
|
||||
)
|
||||
|
||||
assert result == "hello world"
|
||||
execute_workflow_code_template.assert_called_once_with(
|
||||
language=CodeLanguage.JINJA2,
|
||||
code="Hello {{ name }}",
|
||||
inputs={"name": "world"},
|
||||
)
|
||||
|
||||
|
||||
class TestDifyNodeFactoryInit:
|
||||
def test_init_builds_default_dependencies(self):
|
||||
graph_init_params = SimpleNamespace(run_context={"context": "value"})
|
||||
@ -150,6 +173,7 @@ class TestDifyNodeFactoryInit:
|
||||
http_request_config = sentinel.http_request_config
|
||||
credentials_provider = sentinel.credentials_provider
|
||||
model_factory = sentinel.model_factory
|
||||
llm_template_renderer = sentinel.llm_template_renderer
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
@ -172,6 +196,11 @@ class TestDifyNodeFactoryInit:
|
||||
"build_http_request_config",
|
||||
return_value=http_request_config,
|
||||
),
|
||||
patch.object(
|
||||
node_factory,
|
||||
"DefaultLLMTemplateRenderer",
|
||||
return_value=llm_template_renderer,
|
||||
) as llm_renderer_factory,
|
||||
patch.object(
|
||||
node_factory,
|
||||
"build_dify_model_access",
|
||||
@ -186,11 +215,14 @@ class TestDifyNodeFactoryInit:
|
||||
resolve_dify_context.assert_called_once_with(graph_init_params.run_context)
|
||||
build_dify_model_access.assert_called_once_with("tenant-id")
|
||||
renderer_factory.assert_called_once()
|
||||
llm_renderer_factory.assert_called_once()
|
||||
assert renderer_factory.call_args.kwargs["code_executor"] is factory._code_executor
|
||||
assert factory.graph_init_params is graph_init_params
|
||||
assert factory.graph_runtime_state is graph_runtime_state
|
||||
assert factory._dify_context is dify_context
|
||||
assert factory._template_renderer is template_renderer
|
||||
|
||||
assert factory._llm_template_renderer is llm_template_renderer
|
||||
assert factory._document_extractor_unstructured_api_config is unstructured_api_config
|
||||
assert factory._http_request_config is http_request_config
|
||||
assert factory._llm_credentials_provider is credentials_provider
|
||||
@ -242,6 +274,7 @@ class TestDifyNodeFactoryCreateNode:
|
||||
factory._code_executor = sentinel.code_executor
|
||||
factory._code_limits = sentinel.code_limits
|
||||
factory._template_renderer = sentinel.template_renderer
|
||||
factory._llm_template_renderer = sentinel.llm_template_renderer
|
||||
factory._template_transform_max_output_length = 2048
|
||||
factory._http_request_http_client = sentinel.http_client
|
||||
factory._http_request_tool_file_manager_factory = sentinel.tool_file_manager_factory
|
||||
@ -378,8 +411,22 @@ class TestDifyNodeFactoryCreateNode:
|
||||
@pytest.mark.parametrize(
|
||||
("node_type", "constructor_name", "expected_extra_kwargs"),
|
||||
[
|
||||
(BuiltinNodeTypes.LLM, "LLMNode", {"http_client": sentinel.http_client}),
|
||||
(BuiltinNodeTypes.QUESTION_CLASSIFIER, "QuestionClassifierNode", {"http_client": sentinel.http_client}),
|
||||
(
|
||||
BuiltinNodeTypes.LLM,
|
||||
"LLMNode",
|
||||
{
|
||||
"http_client": sentinel.http_client,
|
||||
"template_renderer": sentinel.llm_template_renderer,
|
||||
},
|
||||
),
|
||||
(
|
||||
BuiltinNodeTypes.QUESTION_CLASSIFIER,
|
||||
"QuestionClassifierNode",
|
||||
{
|
||||
"http_client": sentinel.http_client,
|
||||
"template_renderer": sentinel.llm_template_renderer,
|
||||
},
|
||||
),
|
||||
(BuiltinNodeTypes.PARAMETER_EXTRACTOR, "ParameterExtractorNode", {}),
|
||||
],
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user