diff --git a/api/.importlinter b/api/.importlinter
index 4109c007d9..a836d09088 100644
--- a/api/.importlinter
+++ b/api/.importlinter
@@ -103,7 +103,6 @@ ignore_imports =
     dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
     dify_graph.nodes.question_classifier.question_classifier_node -> core.model_manager
     dify_graph.nodes.tool.tool_node -> core.tools.utils.message_transformer
-    dify_graph.nodes.llm.node -> core.helper.code_executor
     dify_graph.nodes.llm.node -> core.llm_generator.output_parser.errors
     dify_graph.nodes.llm.node -> core.llm_generator.output_parser.structured_output
     dify_graph.nodes.llm.node -> core.model_manager
diff --git a/api/core/workflow/node_factory.py b/api/core/workflow/node_factory.py
index ee3b322636..ab34263a79 100644
--- a/api/core/workflow/node_factory.py
+++ b/api/core/workflow/node_factory.py
@@ -45,6 +45,7 @@ from dify_graph.nodes.document_extractor import UnstructuredApiConfig
 from dify_graph.nodes.http_request import build_http_request_config
 from dify_graph.nodes.llm.entities import LLMNodeData
 from dify_graph.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
+from dify_graph.nodes.llm.protocols import TemplateRenderer
 from dify_graph.nodes.parameter_extractor.entities import ParameterExtractorNodeData
 from dify_graph.nodes.question_classifier.entities import QuestionClassifierNodeData
 from dify_graph.nodes.template_transform.template_renderer import (
@@ -228,6 +229,16 @@ class DefaultWorkflowCodeExecutor:
         return isinstance(error, CodeExecutionError)
 
 
+class DefaultLLMTemplateRenderer(TemplateRenderer):
+    def render_jinja2(self, *, template: str, inputs: Mapping[str, Any]) -> str:
+        result = CodeExecutor.execute_workflow_code_template(
+            language=CodeLanguage.JINJA2,
+            code=template,
+            inputs=inputs,
+        )
+        return str(result.get("result", ""))
+
+
 @final
 class DifyNodeFactory(NodeFactory):
     """
@@ -254,6 +265,7 @@ class DifyNodeFactory(NodeFactory):
             max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
         )
         self._template_renderer = CodeExecutorJinja2TemplateRenderer(code_executor=self._code_executor)
+        self._llm_template_renderer: TemplateRenderer = DefaultLLMTemplateRenderer()
         self._template_transform_max_output_length = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
         self._http_request_http_client = ssrf_proxy
         self._http_request_tool_file_manager_factory = ToolFileManager
@@ -391,6 +403,8 @@ class DifyNodeFactory(NodeFactory):
                 model_instance=model_instance,
             ),
         }
+        if validated_node_data.type in {BuiltinNodeTypes.LLM, BuiltinNodeTypes.QUESTION_CLASSIFIER}:
+            node_init_kwargs["template_renderer"] = self._llm_template_renderer
         if include_http_client:
             node_init_kwargs["http_client"] = self._http_request_http_client
         return node_init_kwargs
diff --git a/api/dify_graph/nodes/llm/llm_utils.py b/api/dify_graph/nodes/llm/llm_utils.py
index ca478a09f8..073dce232f 100644
--- a/api/dify_graph/nodes/llm/llm_utils.py
+++ b/api/dify_graph/nodes/llm/llm_utils.py
@@ -1,34 +1,53 @@
+from __future__ import annotations
+
 from collections.abc import Sequence
-from typing import cast
+from typing import Any, cast
 
 from core.model_manager import ModelInstance
+from dify_graph.file import FileType, file_manager
 from dify_graph.file.models import File
-from dify_graph.model_runtime.entities import PromptMessageRole
-from dify_graph.model_runtime.entities.message_entities import (
+from dify_graph.model_runtime.entities import (
     ImagePromptMessageContent,
     PromptMessage,
+    PromptMessageContentType,
+    PromptMessageRole,
     TextPromptMessageContent,
 )
-from dify_graph.model_runtime.entities.model_entities import AIModelEntity
+from dify_graph.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    PromptMessageContentUnionTypes,
+    SystemPromptMessage,
+    UserPromptMessage,
+)
+from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelPropertyKey
 from dify_graph.model_runtime.memory import PromptMessageMemory
 from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from dify_graph.nodes.base.entities import VariableSelector
 from dify_graph.runtime import VariablePool
-from dify_graph.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment
+from dify_graph.variables import ArrayFileSegment, FileSegment
+from dify_graph.variables.segments import ArrayAnySegment, NoneSegment
 
-from .exc import InvalidVariableTypeError
+from .entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, MemoryConfig
+from .exc import (
+    InvalidVariableTypeError,
+    MemoryRolePrefixRequiredError,
+    NoPromptFoundError,
+    TemplateTypeNotSupportError,
+)
+from .protocols import TemplateRenderer
 
 
 def fetch_model_schema(*, model_instance: ModelInstance) -> AIModelEntity:
     model_schema = cast(LargeLanguageModel, model_instance.model_type_instance).get_model_schema(
         model_instance.model_name,
-        model_instance.credentials,
+        dict(model_instance.credentials),
     )
     if not model_schema:
         raise ValueError(f"Model schema not found for {model_instance.model_name}")
     return model_schema
 
 
-def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequence["File"]:
+def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequence[File]:
     variable = variable_pool.get(selector)
     if variable is None:
         return []
@@ -89,3 +108,366 @@ def fetch_memory_text(
         human_prefix=human_prefix,
         ai_prefix=ai_prefix,
     )
+
+
+def fetch_prompt_messages(
+    *,
+    sys_query: str | None = None,
+    sys_files: Sequence[File],
+    context: str | None = None,
+    memory: PromptMessageMemory | None = None,
+    model_instance: ModelInstance,
+    prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
+    stop: Sequence[str] | None = None,
+    memory_config: MemoryConfig | None = None,
+    vision_enabled: bool = False,
+    vision_detail: ImagePromptMessageContent.DETAIL,
+    variable_pool: VariablePool,
+    jinja2_variables: Sequence[VariableSelector],
+    context_files: list[File] | None = None,
+    template_renderer: TemplateRenderer | None = None,
+) -> tuple[Sequence[PromptMessage], Sequence[str] | None]:
+    prompt_messages: list[PromptMessage] = []
+    model_schema = fetch_model_schema(model_instance=model_instance)
+
+    if isinstance(prompt_template, list):
+        prompt_messages.extend(
+            handle_list_messages(
+                messages=prompt_template,
+                context=context,
+                jinja2_variables=jinja2_variables,
+                variable_pool=variable_pool,
+                vision_detail_config=vision_detail,
+                template_renderer=template_renderer,
+            )
+        )
+
+        prompt_messages.extend(
+            handle_memory_chat_mode(
+                memory=memory,
+                memory_config=memory_config,
+                model_instance=model_instance,
+            )
+        )
+
+        if sys_query:
+            prompt_messages.extend(
+                handle_list_messages(
+                    messages=[
+                        LLMNodeChatModelMessage(
+                            text=sys_query,
+                            role=PromptMessageRole.USER,
+                            edition_type="basic",
+                        )
+                    ],
+                    context="",
+                    jinja2_variables=[],
+                    variable_pool=variable_pool,
+                    vision_detail_config=vision_detail,
+                    template_renderer=template_renderer,
+                )
+            )
+    elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
+        prompt_messages.extend(
+            handle_completion_template(
+                template=prompt_template,
+                context=context,
+                jinja2_variables=jinja2_variables,
+                variable_pool=variable_pool,
+                template_renderer=template_renderer,
+            )
+        )
+
+        memory_text = handle_memory_completion_mode(
+            memory=memory,
+            memory_config=memory_config,
+            model_instance=model_instance,
+        )
+        prompt_content = prompt_messages[0].content
+        if isinstance(prompt_content, str):
+            prompt_content = str(prompt_content)
+            if "#histories#" in prompt_content:
+                prompt_content = prompt_content.replace("#histories#", memory_text)
+            else:
+                prompt_content = memory_text + "\n" + prompt_content
+            prompt_messages[0].content = prompt_content
+        elif isinstance(prompt_content, list):
+            for content_item in prompt_content:
+                if isinstance(content_item, TextPromptMessageContent):
+                    if "#histories#" in content_item.data:
+                        content_item.data = content_item.data.replace("#histories#", memory_text)
+                    else:
+                        content_item.data = memory_text + "\n" + content_item.data
+        else:
+            raise ValueError("Invalid prompt content type")
+
+        if sys_query:
+            if isinstance(prompt_content, str):
+                prompt_messages[0].content = str(prompt_messages[0].content).replace("#sys.query#", sys_query)
+            elif isinstance(prompt_content, list):
+                for content_item in prompt_content:
+                    if isinstance(content_item, TextPromptMessageContent):
+                        content_item.data = sys_query + "\n" + content_item.data
+            else:
+                raise ValueError("Invalid prompt content type")
+    else:
+        raise TemplateTypeNotSupportError(type_name=str(type(prompt_template)))
+
+    _append_file_prompts(
+        prompt_messages=prompt_messages,
+        files=sys_files,
+        vision_enabled=vision_enabled,
+        vision_detail=vision_detail,
+    )
+    _append_file_prompts(
+        prompt_messages=prompt_messages,
+        files=context_files or [],
+        vision_enabled=vision_enabled,
+        vision_detail=vision_detail,
+    )
+
+    filtered_prompt_messages: list[PromptMessage] = []
+    for prompt_message in prompt_messages:
+        if isinstance(prompt_message.content, list):
+            prompt_message_content: list[PromptMessageContentUnionTypes] = []
+            for content_item in prompt_message.content:
+                if not model_schema.features:
+                    if content_item.type == PromptMessageContentType.TEXT:
+                        prompt_message_content.append(content_item)
+                    continue
+
+                if (
+                    (
+                        content_item.type == PromptMessageContentType.IMAGE
+                        and ModelFeature.VISION not in model_schema.features
+                    )
+                    or (
+                        content_item.type == PromptMessageContentType.DOCUMENT
+                        and ModelFeature.DOCUMENT not in model_schema.features
+                    )
+                    or (
+                        content_item.type == PromptMessageContentType.VIDEO
+                        and ModelFeature.VIDEO not in model_schema.features
+                    )
+                    or (
+                        content_item.type == PromptMessageContentType.AUDIO
+                        and ModelFeature.AUDIO not in model_schema.features
+                    )
+                ):
+                    continue
+                prompt_message_content.append(content_item)
+            if prompt_message_content:
+                prompt_message.content = prompt_message_content
+                filtered_prompt_messages.append(prompt_message)
+        elif not prompt_message.is_empty():
+            filtered_prompt_messages.append(prompt_message)
+
+    if len(filtered_prompt_messages) == 0:
+        raise NoPromptFoundError(
+            "No prompt found in the LLM configuration. Please ensure a prompt is properly configured before proceeding."
+        )
+
+    return filtered_prompt_messages, stop
+
+
+def handle_list_messages(
+    *,
+    messages: Sequence[LLMNodeChatModelMessage],
+    context: str | None,
+    jinja2_variables: Sequence[VariableSelector],
+    variable_pool: VariablePool,
+    vision_detail_config: ImagePromptMessageContent.DETAIL,
+    template_renderer: TemplateRenderer | None = None,
+) -> Sequence[PromptMessage]:
+    prompt_messages: list[PromptMessage] = []
+    for message in messages:
+        if message.edition_type == "jinja2":
+            result_text = render_jinja2_message(
+                template=message.jinja2_text or "",
+                jinja2_variables=jinja2_variables,
+                variable_pool=variable_pool,
+                template_renderer=template_renderer,
+            )
+            prompt_messages.append(
+                combine_message_content_with_role(
+                    contents=[TextPromptMessageContent(data=result_text)],
+                    role=message.role,
+                )
+            )
+            continue
+
+        template = message.text.replace("{#context#}", context) if context else message.text
+        segment_group = variable_pool.convert_template(template)
+        file_contents: list[PromptMessageContentUnionTypes] = []
+        for segment in segment_group.value:
+            if isinstance(segment, ArrayFileSegment):
+                for file in segment.value:
+                    if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
+                        file_contents.append(
+                            file_manager.to_prompt_message_content(file, image_detail_config=vision_detail_config)
+                        )
+            elif isinstance(segment, FileSegment):
+                file = segment.value
+                if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
+                    file_contents.append(
+                        file_manager.to_prompt_message_content(file, image_detail_config=vision_detail_config)
+                    )
+
+        if segment_group.text:
+            prompt_messages.append(
+                combine_message_content_with_role(
+                    contents=[TextPromptMessageContent(data=segment_group.text)],
+                    role=message.role,
+                )
+            )
+        if file_contents:
+            prompt_messages.append(combine_message_content_with_role(contents=file_contents, role=message.role))
+
+    return prompt_messages
+
+
+def render_jinja2_message(
+    *,
+    template: str,
+    jinja2_variables: Sequence[VariableSelector],
+    variable_pool: VariablePool,
+    template_renderer: TemplateRenderer | None = None,
+) -> str:
+    if not template:
+        return ""
+    if template_renderer is None:
+        raise ValueError("template_renderer is required for jinja2 prompt rendering")
+
+    jinja2_inputs: dict[str, Any] = {}
+    for jinja2_variable in jinja2_variables:
+        variable = variable_pool.get(jinja2_variable.value_selector)
+        jinja2_inputs[jinja2_variable.variable] = variable.to_object() if variable else ""
+    return template_renderer.render_jinja2(template=template, inputs=jinja2_inputs)
+
+
+def handle_completion_template(
+    *,
+    template: LLMNodeCompletionModelPromptTemplate,
+    context: str | None,
+    jinja2_variables: Sequence[VariableSelector],
+    variable_pool: VariablePool,
+    template_renderer: TemplateRenderer | None = None,
+) -> Sequence[PromptMessage]:
+    if template.edition_type == "jinja2":
+        result_text = render_jinja2_message(
+            template=template.jinja2_text or "",
+            jinja2_variables=jinja2_variables,
+            variable_pool=variable_pool,
+            template_renderer=template_renderer,
+        )
+    else:
+        template_text = template.text.replace("{#context#}", context) if context else template.text
+        result_text = variable_pool.convert_template(template_text).text
+    return [
+        combine_message_content_with_role(
+            contents=[TextPromptMessageContent(data=result_text)],
+            role=PromptMessageRole.USER,
+        )
+    ]
+
+
+def combine_message_content_with_role(
+    *,
+    contents: str | list[PromptMessageContentUnionTypes] | None = None,
+    role: PromptMessageRole,
+) -> PromptMessage:
+    match role:
+        case PromptMessageRole.USER:
+            return UserPromptMessage(content=contents)
+        case PromptMessageRole.ASSISTANT:
+            return AssistantPromptMessage(content=contents)
+        case PromptMessageRole.SYSTEM:
+            return SystemPromptMessage(content=contents)
+        case _:
+            raise NotImplementedError(f"Role {role} is not supported")
+
+
+def calculate_rest_token(*, prompt_messages: list[PromptMessage], model_instance: ModelInstance) -> int:
+    rest_tokens = 2000
+    runtime_model_schema = fetch_model_schema(model_instance=model_instance)
+    runtime_model_parameters = model_instance.parameters
+
+    model_context_tokens = runtime_model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
+    if model_context_tokens:
+        curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
+
+        max_tokens = 0
+        for parameter_rule in runtime_model_schema.parameter_rules:
+            if parameter_rule.name == "max_tokens" or (
+                parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
+            ):
+                max_tokens = (
+                    runtime_model_parameters.get(parameter_rule.name)
+                    or runtime_model_parameters.get(str(parameter_rule.use_template))
+                    or 0
+                )
+
+        rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
+        rest_tokens = max(rest_tokens, 0)
+
+    return rest_tokens
+
+
+def handle_memory_chat_mode(
+    *,
+    memory: PromptMessageMemory | None,
+    memory_config: MemoryConfig | None,
+    model_instance: ModelInstance,
+) -> Sequence[PromptMessage]:
+    if not memory or not memory_config:
+        return []
+    rest_tokens = calculate_rest_token(prompt_messages=[], model_instance=model_instance)
+    return memory.get_history_prompt_messages(
+        max_token_limit=rest_tokens,
+        message_limit=memory_config.window.size if memory_config.window.enabled else None,
+    )
+
+
+def handle_memory_completion_mode(
+    *,
+    memory: PromptMessageMemory | None,
+    memory_config: MemoryConfig | None,
+    model_instance: ModelInstance,
+) -> str:
+    if not memory or not memory_config:
+        return ""
+
+    rest_tokens = calculate_rest_token(prompt_messages=[], model_instance=model_instance)
+    if not memory_config.role_prefix:
+        raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
+
+    return fetch_memory_text(
+        memory=memory,
+        max_token_limit=rest_tokens,
+        message_limit=memory_config.window.size if memory_config.window.enabled else None,
+        human_prefix=memory_config.role_prefix.user,
+        ai_prefix=memory_config.role_prefix.assistant,
+    )
+
+
+def _append_file_prompts(
+    *,
+    prompt_messages: list[PromptMessage],
+    files: Sequence[File],
+    vision_enabled: bool,
+    vision_detail: ImagePromptMessageContent.DETAIL,
+) -> None:
+    if not vision_enabled or not files:
+        return
+
+    file_prompts = [file_manager.to_prompt_message_content(file, image_detail_config=vision_detail) for file in files]
+    if (
+        prompt_messages
+        and isinstance(prompt_messages[-1], UserPromptMessage)
+        and isinstance(prompt_messages[-1].content, list)
+    ):
+        existing_contents = prompt_messages[-1].content
+        assert isinstance(existing_contents, list)
+        prompt_messages[-1] = UserPromptMessage(content=file_prompts + existing_contents)
+    else:
+        prompt_messages.append(UserPromptMessage(content=file_prompts))
diff --git a/api/dify_graph/nodes/llm/node.py b/api/dify_graph/nodes/llm/node.py
index c3529867b7..5ed90ed7e3 100644
--- a/api/dify_graph/nodes/llm/node.py
+++ b/api/dify_graph/nodes/llm/node.py
@@ -11,7 +11,6 @@ from typing import TYPE_CHECKING, Any, Literal
 
 from sqlalchemy import select
 
-from core.helper.code_executor import CodeExecutor, CodeLanguage
 from core.llm_generator.output_parser.errors import OutputParserError
 from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
 from core.model_manager import ModelInstance
@@ -28,11 +27,10 @@ from dify_graph.enums import (
     WorkflowNodeExecutionMetadataKey,
     WorkflowNodeExecutionStatus,
 )
-from dify_graph.file import File, FileTransferMethod, FileType, file_manager
+from dify_graph.file import File, FileTransferMethod, FileType
 from dify_graph.model_runtime.entities import (
     ImagePromptMessageContent,
     PromptMessage,
-    PromptMessageContentType,
     TextPromptMessageContent,
 )
 from dify_graph.model_runtime.entities.llm_entities import (
@@ -43,14 +41,7 @@ from dify_graph.model_runtime.entities.llm_entities import (
     LLMStructuredOutput,
     LLMUsage,
 )
-from dify_graph.model_runtime.entities.message_entities import (
-    AssistantPromptMessage,
-    PromptMessageContentUnionTypes,
-    PromptMessageRole,
-    SystemPromptMessage,
-    UserPromptMessage,
-)
-from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
+from dify_graph.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
 from dify_graph.model_runtime.memory import PromptMessageMemory
 from dify_graph.model_runtime.utils.encoders import jsonable_encoder
 from dify_graph.node_events import (
@@ -64,13 +55,12 @@ from dify_graph.node_events import (
 from dify_graph.nodes.base.entities import VariableSelector
 from dify_graph.nodes.base.node import Node
 from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
-from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
+from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer
 from dify_graph.nodes.protocols import HttpClientProtocol
 from dify_graph.runtime import VariablePool
 from dify_graph.variables import (
     ArrayFileSegment,
     ArraySegment,
-    FileSegment,
     NoneSegment,
     ObjectSegment,
     StringSegment,
@@ -89,9 +79,6 @@ from .exc import (
     InvalidContextStructureError,
     InvalidVariableTypeError,
     LLMNodeError,
-    MemoryRolePrefixRequiredError,
-    NoPromptFoundError,
-    TemplateTypeNotSupportError,
     VariableNotFoundError,
 )
 from .file_saver import FileSaverImpl, LLMFileSaver
@@ -118,6 +105,7 @@ class LLMNode(Node[LLMNodeData]):
     _model_factory: ModelFactory
     _model_instance: ModelInstance
     _memory: PromptMessageMemory | None
+    _template_renderer: TemplateRenderer
 
     def __init__(
         self,
@@ -130,6 +118,7 @@ class LLMNode(Node[LLMNodeData]):
         model_factory: ModelFactory,
         model_instance: ModelInstance,
         http_client: HttpClientProtocol,
+        template_renderer: TemplateRenderer,
         memory: PromptMessageMemory | None = None,
         llm_file_saver: LLMFileSaver | None = None,
     ):
@@ -146,6 +135,7 @@ class LLMNode(Node[LLMNodeData]):
         self._model_factory = model_factory
         self._model_instance = model_instance
         self._memory = memory
+        self._template_renderer = template_renderer
 
         if llm_file_saver is None:
             dify_ctx = self.require_dify_context()
@@ -240,6 +230,7 @@ class LLMNode(Node[LLMNodeData]):
                 variable_pool=variable_pool,
                 jinja2_variables=self.node_data.prompt_config.jinja2_variables,
                 context_files=context_files,
+                template_renderer=self._template_renderer,
             )
 
             # handle invoke result
@@ -773,182 +764,24 @@ class LLMNode(Node[LLMNodeData]):
         variable_pool: VariablePool,
         jinja2_variables: Sequence[VariableSelector],
         context_files: list[File] | None = None,
+        template_renderer: TemplateRenderer | None = None,
     ) -> tuple[Sequence[PromptMessage], Sequence[str] | None]:
-        prompt_messages: list[PromptMessage] = []
-        model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
-
-        if isinstance(prompt_template, list):
-            # For chat model
-            prompt_messages.extend(
-                LLMNode.handle_list_messages(
-                    messages=prompt_template,
-                    context=context,
-                    jinja2_variables=jinja2_variables,
-                    variable_pool=variable_pool,
-                    vision_detail_config=vision_detail,
-                )
-            )
-
-            # Get memory messages for chat mode
-            memory_messages = _handle_memory_chat_mode(
-                memory=memory,
-                memory_config=memory_config,
-                model_instance=model_instance,
-            )
-            # Extend prompt_messages with memory messages
-            prompt_messages.extend(memory_messages)
-
-            # Add current query to the prompt messages
-            if sys_query:
-                message = LLMNodeChatModelMessage(
-                    text=sys_query,
-                    role=PromptMessageRole.USER,
-                    edition_type="basic",
-                )
-                prompt_messages.extend(
-                    LLMNode.handle_list_messages(
-                        messages=[message],
-                        context="",
-                        jinja2_variables=[],
-                        variable_pool=variable_pool,
-                        vision_detail_config=vision_detail,
-                    )
-                )
-
-        elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
-            # For completion model
-            prompt_messages.extend(
-                _handle_completion_template(
-                    template=prompt_template,
-                    context=context,
-                    jinja2_variables=jinja2_variables,
-                    variable_pool=variable_pool,
-                )
-            )
-
-            # Get memory text for completion model
-            memory_text = _handle_memory_completion_mode(
-                memory=memory,
-                memory_config=memory_config,
-                model_instance=model_instance,
-            )
-            # Insert histories into the prompt
-            prompt_content = prompt_messages[0].content
-            # For issue #11247 - Check if prompt content is a string or a list
-            if isinstance(prompt_content, str):
-                prompt_content = str(prompt_content)
-                if "#histories#" in prompt_content:
-                    prompt_content = prompt_content.replace("#histories#", memory_text)
-                else:
-                    prompt_content = memory_text + "\n" + prompt_content
-                prompt_messages[0].content = prompt_content
-            elif isinstance(prompt_content, list):
-                for content_item in prompt_content:
-                    if isinstance(content_item, TextPromptMessageContent):
-                        if "#histories#" in content_item.data:
-                            content_item.data = content_item.data.replace("#histories#", memory_text)
-                        else:
-                            content_item.data = memory_text + "\n" + content_item.data
-            else:
-                raise ValueError("Invalid prompt content type")
-
-            # Add current query to the prompt message
-            if sys_query:
-                if isinstance(prompt_content, str):
-                    prompt_content = str(prompt_messages[0].content).replace("#sys.query#", sys_query)
-                    prompt_messages[0].content = prompt_content
-                elif isinstance(prompt_content, list):
-                    for content_item in prompt_content:
-                        if isinstance(content_item, TextPromptMessageContent):
-                            content_item.data = sys_query + "\n" + content_item.data
-                else:
-                    raise ValueError("Invalid prompt content type")
-        else:
-            raise TemplateTypeNotSupportError(type_name=str(type(prompt_template)))
-
-        # The sys_files will be deprecated later
-        if vision_enabled and sys_files:
-            file_prompts = []
-            for file in sys_files:
-                file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
-                file_prompts.append(file_prompt)
-            # If last prompt is a user prompt, add files into its contents,
-            # otherwise append a new user prompt
-            if (
-                len(prompt_messages) > 0
-                and isinstance(prompt_messages[-1], UserPromptMessage)
-                and isinstance(prompt_messages[-1].content, list)
-            ):
-                prompt_messages[-1] = UserPromptMessage(content=file_prompts + prompt_messages[-1].content)
-            else:
-                prompt_messages.append(UserPromptMessage(content=file_prompts))
-
-        # The context_files
-        if vision_enabled and context_files:
-            file_prompts = []
-            for file in context_files:
-                file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
-                file_prompts.append(file_prompt)
-            # If last prompt is a user prompt, add files into its contents,
-            # otherwise append a new user prompt
-            if (
-                len(prompt_messages) > 0
-                and isinstance(prompt_messages[-1], UserPromptMessage)
-                and isinstance(prompt_messages[-1].content, list)
-            ):
-                prompt_messages[-1] = UserPromptMessage(content=file_prompts + prompt_messages[-1].content)
-            else:
-                prompt_messages.append(UserPromptMessage(content=file_prompts))
-
-        # Remove empty messages and filter unsupported content
-        filtered_prompt_messages = []
-        for prompt_message in prompt_messages:
-            if isinstance(prompt_message.content, list):
-                prompt_message_content: list[PromptMessageContentUnionTypes] = []
-                for content_item in prompt_message.content:
-                    # Skip content if features are not defined
-                    if not model_schema.features:
-                        if content_item.type != PromptMessageContentType.TEXT:
-                            continue
-                        prompt_message_content.append(content_item)
-                        continue
-
-                    # Skip content if corresponding feature is not supported
-                    if (
-                        (
-                            content_item.type == PromptMessageContentType.IMAGE
-                            and ModelFeature.VISION not in model_schema.features
-                        )
-                        or (
-                            content_item.type == PromptMessageContentType.DOCUMENT
-                            and ModelFeature.DOCUMENT not in model_schema.features
-                        )
-                        or (
-                            content_item.type == PromptMessageContentType.VIDEO
-                            and ModelFeature.VIDEO not in model_schema.features
-                        )
-                        or (
-                            content_item.type == PromptMessageContentType.AUDIO
-                            and ModelFeature.AUDIO not in model_schema.features
-                        )
-                    ):
-                        continue
-                    prompt_message_content.append(content_item)
-                if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT:
-                    prompt_message.content = prompt_message_content[0].data
-                else:
-                    prompt_message.content = prompt_message_content
-            if prompt_message.is_empty():
-                continue
-            filtered_prompt_messages.append(prompt_message)
-
-        if len(filtered_prompt_messages) == 0:
-            raise NoPromptFoundError(
-                "No prompt found in the LLM configuration. "
-                "Please ensure a prompt is properly configured before proceeding."
-            )
-
-        return filtered_prompt_messages, stop
+        return llm_utils.fetch_prompt_messages(
+            sys_query=sys_query,
+            sys_files=sys_files,
+            context=context,
+            memory=memory,
+            model_instance=model_instance,
+            prompt_template=prompt_template,
+            stop=stop,
+            memory_config=memory_config,
+            vision_enabled=vision_enabled,
+            vision_detail=vision_detail,
+            variable_pool=variable_pool,
+            jinja2_variables=jinja2_variables,
+            context_files=context_files,
+            template_renderer=template_renderer,
+        )
 
     @classmethod
     def _extract_variable_selector_to_variable_mapping(
@@ -1048,59 +881,16 @@ class LLMNode(Node[LLMNodeData]):
         jinja2_variables: Sequence[VariableSelector],
         variable_pool: VariablePool,
         vision_detail_config: ImagePromptMessageContent.DETAIL,
+        template_renderer: TemplateRenderer | None = None,
     ) -> Sequence[PromptMessage]:
-        prompt_messages: list[PromptMessage] = []
-        for message in messages:
-            if message.edition_type == "jinja2":
-                result_text = _render_jinja2_message(
-                    template=message.jinja2_text or "",
-                    jinja2_variables=jinja2_variables,
-                    variable_pool=variable_pool,
-                )
-                prompt_message = _combine_message_content_with_role(
-                    contents=[TextPromptMessageContent(data=result_text)], role=message.role
-                )
-                prompt_messages.append(prompt_message)
-            else:
-                # Get segment group from basic message
-                if context:
-                    template = message.text.replace("{#context#}", context)
-                else:
-                    template = message.text
-                segment_group = variable_pool.convert_template(template)
-
-                # Process segments for images
-                file_contents = []
-                for segment in segment_group.value:
-                    if isinstance(segment, ArrayFileSegment):
-                        for file in segment.value:
-                            if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
-                                file_content = file_manager.to_prompt_message_content(
-                                    file, image_detail_config=vision_detail_config
-                                )
-                                file_contents.append(file_content)
-                    elif isinstance(segment, FileSegment):
-                        file = segment.value
-                        if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
-                            file_content = file_manager.to_prompt_message_content(
-                                file, image_detail_config=vision_detail_config
-                            )
-                            file_contents.append(file_content)
-
-                # Create message with text from all segments
-                plain_text = segment_group.text
-                if plain_text:
-                    prompt_message = _combine_message_content_with_role(
-                        contents=[TextPromptMessageContent(data=plain_text)], role=message.role
-                    )
-                    prompt_messages.append(prompt_message)
-
-                if file_contents:
-                    # Create message with image contents
-                    prompt_message = _combine_message_content_with_role(contents=file_contents, role=message.role)
-                    prompt_messages.append(prompt_message)
-
-        return prompt_messages
+        return llm_utils.handle_list_messages(
+            messages=messages,
+            context=context,
+            jinja2_variables=jinja2_variables,
+            variable_pool=variable_pool,
+            vision_detail_config=vision_detail_config,
+            template_renderer=template_renderer,
+        )
 
     @staticmethod
     def handle_blocking_result(
@@ -1239,152 +1029,3 @@ class LLMNode(Node[LLMNodeData]):
     @property
     def model_instance(self) -> ModelInstance:
         return self._model_instance
-
-
-def _combine_message_content_with_role(
-    *, contents: str | list[PromptMessageContentUnionTypes] | None = None, role: PromptMessageRole
-):
-    match role:
-        case PromptMessageRole.USER:
-            return UserPromptMessage(content=contents)
-        case PromptMessageRole.ASSISTANT:
-            return AssistantPromptMessage(content=contents)
-        case PromptMessageRole.SYSTEM:
-            return SystemPromptMessage(content=contents)
-        case _:
-            raise NotImplementedError(f"Role {role} is not supported")
-
-
-def _render_jinja2_message(
-    *,
-    template: str,
-    jinja2_variables: Sequence[VariableSelector],
-    variable_pool: VariablePool,
-):
-    if not template:
-        return ""
-
-    jinja2_inputs = {}
-    for jinja2_variable in jinja2_variables:
-        variable = variable_pool.get(jinja2_variable.value_selector)
-        jinja2_inputs[jinja2_variable.variable] = variable.to_object() if variable else ""
-    code_execute_resp = CodeExecutor.execute_workflow_code_template(
-        language=CodeLanguage.JINJA2,
-        code=template,
-        inputs=jinja2_inputs,
-    )
-    result_text = code_execute_resp["result"]
-    return result_text
-
-
-def _calculate_rest_token(
-    *,
-    prompt_messages: list[PromptMessage],
-    model_instance: ModelInstance,
-) -> int:
-    rest_tokens = 2000
-    runtime_model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
-    runtime_model_parameters = model_instance.parameters
-
-    model_context_tokens = runtime_model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
-    if model_context_tokens:
-        curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
-
-        max_tokens = 0
-        for parameter_rule in runtime_model_schema.parameter_rules:
-            if parameter_rule.name == "max_tokens" or (
-                parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
-            ):
-                max_tokens = (
-                    runtime_model_parameters.get(parameter_rule.name)
-                    or runtime_model_parameters.get(str(parameter_rule.use_template))
-                    or 0
-                )
-
-        rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
-        rest_tokens = max(rest_tokens, 0)
-
-    return rest_tokens
-
-
-def _handle_memory_chat_mode(
-    *,
-    memory: PromptMessageMemory | None,
-    memory_config: MemoryConfig | None,
-    model_instance: ModelInstance,
-) -> Sequence[PromptMessage]:
-    memory_messages: Sequence[PromptMessage] = []
-    # Get messages from memory for chat model
-    if memory and memory_config:
-        rest_tokens = _calculate_rest_token(
-            prompt_messages=[],
-            model_instance=model_instance,
-        )
-        memory_messages = memory.get_history_prompt_messages(
-            max_token_limit=rest_tokens,
-            message_limit=memory_config.window.size if memory_config.window.enabled else None,
-        )
-    return memory_messages
-
-
-def _handle_memory_completion_mode(
-    *,
-    memory: PromptMessageMemory | None,
-    memory_config: MemoryConfig | None,
-    model_instance: ModelInstance,
-) -> str:
-    memory_text = ""
-    # Get history text from memory for completion model
-    if memory and memory_config:
-        rest_tokens = _calculate_rest_token(
-            prompt_messages=[],
-            model_instance=model_instance,
-        )
-        if not memory_config.role_prefix:
-            raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
-        memory_text = llm_utils.fetch_memory_text(
-            memory=memory,
-            max_token_limit=rest_tokens,
-            message_limit=memory_config.window.size if memory_config.window.enabled else None,
-            human_prefix=memory_config.role_prefix.user,
-            ai_prefix=memory_config.role_prefix.assistant,
-        )
-    return memory_text
-
-
-def _handle_completion_template(
-    *,
-    template: LLMNodeCompletionModelPromptTemplate,
-    context: str | None,
-    jinja2_variables: Sequence[VariableSelector],
-    variable_pool: VariablePool,
-) -> Sequence[PromptMessage]:
-    """Handle completion template processing outside of LLMNode class.
-
-    Args:
-        template: The completion model prompt template
-        context: Optional context string
-        jinja2_variables: Variables for jinja2 template rendering
-        variable_pool: Variable pool for template conversion
-
-    Returns:
-        Sequence of prompt messages
-    """
-    prompt_messages = []
-    if template.edition_type == "jinja2":
-        result_text = _render_jinja2_message(
-            template=template.jinja2_text or "",
-            jinja2_variables=jinja2_variables,
-            variable_pool=variable_pool,
-        )
-    else:
-        if context:
-            template_text = template.text.replace("{#context#}", context)
-        else:
-            template_text = template.text
-        result_text = variable_pool.convert_template(template_text).text
-    prompt_message = _combine_message_content_with_role(
-        contents=[TextPromptMessageContent(data=result_text)], role=PromptMessageRole.USER
-    )
-    prompt_messages.append(prompt_message)
-    return prompt_messages
diff --git a/api/dify_graph/nodes/llm/protocols.py b/api/dify_graph/nodes/llm/protocols.py
index 8e0365299d..9e95d341c9 100644
--- a/api/dify_graph/nodes/llm/protocols.py
+++ b/api/dify_graph/nodes/llm/protocols.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+from collections.abc import Mapping
 from typing import Any, Protocol
 
 from core.model_manager import ModelInstance
@@ -19,3 +20,11 @@ class ModelFactory(Protocol):
     def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance:
         """Create a model instance that is ready for schema lookup and invocation."""
         ...
+
+
+class TemplateRenderer(Protocol):
+    """Port for rendering prompt templates used by LLM-compatible nodes."""
+
+    def render_jinja2(self, *, template: str, inputs: Mapping[str, Any]) -> str:
+        """Render the given Jinja2 template into plain text."""
+        ...
diff --git a/api/dify_graph/nodes/question_classifier/question_classifier_node.py b/api/dify_graph/nodes/question_classifier/question_classifier_node.py
index 84e77a460c..59d0a2a4d8 100644
--- a/api/dify_graph/nodes/question_classifier/question_classifier_node.py
+++ b/api/dify_graph/nodes/question_classifier/question_classifier_node.py
@@ -28,7 +28,7 @@ from dify_graph.nodes.llm import (
     llm_utils,
 )
 from dify_graph.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
-from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
+from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer
 from dify_graph.nodes.protocols import HttpClientProtocol
 from libs.json_in_md_parser import parse_and_check_json_markdown
 
@@ -59,6 +59,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
     _model_factory: "ModelFactory"
     _model_instance: ModelInstance
    _memory: PromptMessageMemory | None
+    _template_renderer: TemplateRenderer
 
     def __init__(
         self,
@@ -71,6 +72,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
         model_factory: "ModelFactory",
         model_instance: ModelInstance,
         http_client: HttpClientProtocol,
+        template_renderer: TemplateRenderer,
         memory: PromptMessageMemory | None = None,
         llm_file_saver: LLMFileSaver | None = None,
     ):
@@ -87,6 +89,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
         self._model_factory = model_factory
         self._model_instance = model_instance
         self._memory = memory
+        self._template_renderer = template_renderer
 
         if llm_file_saver is None:
             dify_ctx = self.require_dify_context()
@@ -142,7 +145,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
         # If both self._get_prompt_template and self._fetch_prompt_messages append a user prompt,
         # two consecutive user prompts will be generated, causing model's error.
         # To avoid this, set sys_query to an empty string so that only one user prompt is appended at the end.
-        prompt_messages, stop = LLMNode.fetch_prompt_messages(
+        prompt_messages, stop = llm_utils.fetch_prompt_messages(
             prompt_template=prompt_template,
             sys_query="",
             memory=memory,
@@ -153,6 +156,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
             vision_detail=node_data.vision.configs.detail,
             variable_pool=variable_pool,
             jinja2_variables=[],
+            template_renderer=self._template_renderer,
         )
 
         result_text = ""
@@ -287,7 +291,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
         model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
 
         prompt_template = self._get_prompt_template(node_data, query, None, 2000)
-        prompt_messages, _ = LLMNode.fetch_prompt_messages(
+        prompt_messages, _ = llm_utils.fetch_prompt_messages(
             prompt_template=prompt_template,
             sys_query="",
             sys_files=[],
@@ -300,6 +304,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
             vision_detail=node_data.vision.configs.detail,
             variable_pool=self.graph_runtime_state.variable_pool,
             jinja2_variables=[],
+            template_renderer=self._template_renderer,
         )
 
         rest_tokens = 2000
diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py
index 2aca9f5157..d628348f1e 100644
--- a/api/tests/integration_tests/workflow/nodes/test_llm.py
+++ b/api/tests/integration_tests/workflow/nodes/test_llm.py
@@ -10,7 +10,7 @@ from core.model_manager import ModelInstance
 from dify_graph.enums import WorkflowNodeExecutionStatus
 from dify_graph.node_events import StreamCompletedEvent
 from dify_graph.nodes.llm.node import LLMNode
-from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
+from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer
 from dify_graph.nodes.protocols import HttpClientProtocol
 from dify_graph.runtime import GraphRuntimeState, VariablePool
 from dify_graph.system_variable import SystemVariable
@@ -75,6 +75,7 @@ def init_llm_node(config: dict) -> LLMNode:
         credentials_provider=MagicMock(spec=CredentialsProvider),
         model_factory=MagicMock(spec=ModelFactory),
         model_instance=MagicMock(spec=ModelInstance),
+        template_renderer=MagicMock(spec=TemplateRenderer),
         http_client=MagicMock(spec=HttpClientProtocol),
     )
 
@@ -158,7 +159,7 @@ def test_execute_llm():
         return mock_model_instance
 
     # Mock fetch_prompt_messages to avoid database calls
-    def mock_fetch_prompt_messages_1(**_kwargs):
+    def mock_fetch_prompt_messages_1(*_args, **_kwargs):
        from dify_graph.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
 
         return [
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py
index e117f81ff9..454263bef9 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py
@@ -20,7 +20,7 @@ from dify_graph.nodes.code import CodeNode
 from dify_graph.nodes.document_extractor import DocumentExtractorNode
 from dify_graph.nodes.http_request import HttpRequestNode
 from dify_graph.nodes.llm import LLMNode
-from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
+from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer
 from dify_graph.nodes.parameter_extractor import ParameterExtractorNode
 from dify_graph.nodes.protocols import HttpClientProtocol, ToolFileManagerProtocol
 from dify_graph.nodes.question_classifier import QuestionClassifierNode
@@ -68,6 +68,8 @@ class MockNodeMixin:
             kwargs.setdefault("model_instance", MagicMock(spec=ModelInstance))
             # LLM-like nodes now require an http_client; provide a mock by default for tests.
             kwargs.setdefault("http_client", MagicMock(spec=HttpClientProtocol))
+            if isinstance(self, (LLMNode, QuestionClassifierNode)):
+                kwargs.setdefault("template_renderer", MagicMock(spec=TemplateRenderer))
 
         # Ensure TemplateTransformNode receives a renderer now required by constructor
         if isinstance(self, TemplateTransformNode):
diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
index d56035b6bc..fc96088af1 100644
--- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
@@ -34,8 +34,8 @@ from dify_graph.nodes.llm.entities import (
     VisionConfigOptions,
 )
 from dify_graph.nodes.llm.file_saver import LLMFileSaver
-from dify_graph.nodes.llm.node import LLMNode, _handle_memory_completion_mode
-from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
+from dify_graph.nodes.llm.node import LLMNode
+from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer
 from dify_graph.runtime import GraphRuntimeState, VariablePool
 from dify_graph.system_variable import SystemVariable
 from dify_graph.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment
@@ -107,6 +107,7 @@ def llm_node(
     mock_file_saver = mock.MagicMock(spec=LLMFileSaver)
     mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider)
     mock_model_factory = mock.MagicMock(spec=ModelFactory)
+    mock_template_renderer = mock.MagicMock(spec=TemplateRenderer)
     node_config = {
         "id": "1",
         "data": llm_node_data.model_dump(),
@@ -121,6 +122,7 @@ def llm_node(
         model_factory=mock_model_factory,
         model_instance=mock.MagicMock(spec=ModelInstance),
         llm_file_saver=mock_file_saver,
+        template_renderer=mock_template_renderer,
         http_client=http_client,
     )
     return node
@@ -590,6 +592,33 @@ def test_handle_list_messages_basic(llm_node):
     assert result[0].content == [TextPromptMessageContent(data="Hello, world")]
 
 
+def test_handle_list_messages_jinja2_uses_template_renderer(llm_node):
+    llm_node._template_renderer.render_jinja2.return_value = "Hello, world"
+    messages = [
+        LLMNodeChatModelMessage(
+            text="",
+            jinja2_text="Hello, {{ name }}",
+            role=PromptMessageRole.USER,
+            edition_type="jinja2",
+        )
+    ]
+
+    result = llm_node.handle_list_messages(
+        messages=messages,
+        context=None,
+        jinja2_variables=[],
+        variable_pool=llm_node.graph_runtime_state.variable_pool,
+        vision_detail_config=ImagePromptMessageContent.DETAIL.HIGH,
+        template_renderer=llm_node._template_renderer,
+    )
+
+    assert result == [UserPromptMessage(content=[TextPromptMessageContent(data="Hello, world")])]
+    llm_node._template_renderer.render_jinja2.assert_called_once_with(
+        template="Hello, {{ name }}",
+        inputs={},
+    )
+
+
 def test_handle_memory_completion_mode_uses_prompt_message_interface():
     memory = mock.MagicMock(spec=MockTokenBufferMemory)
     memory.get_history_prompt_messages.return_value = [
@@ -613,8 +642,8 @@ def test_handle_memory_completion_mode_uses_prompt_message_interface():
         window=MemoryConfig.WindowConfig(enabled=True, size=3),
     )
 
-    with mock.patch("dify_graph.nodes.llm.node._calculate_rest_token", return_value=2000) as mock_rest_token:
-        memory_text = _handle_memory_completion_mode(
+    with mock.patch("dify_graph.nodes.llm.llm_utils.calculate_rest_token", return_value=2000) as mock_rest_token:
+        memory_text = llm_utils.handle_memory_completion_mode(
             memory=memory,
             memory_config=memory_config,
             model_instance=model_instance,
@@ -630,6 +659,7 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat
     mock_file_saver: LLMFileSaver = mock.MagicMock(spec=LLMFileSaver)
     mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider)
     mock_model_factory = mock.MagicMock(spec=ModelFactory)
+    mock_template_renderer = mock.MagicMock(spec=TemplateRenderer)
     node_config = {
         "id": "1",
         "data": llm_node_data.model_dump(),
@@ -644,6 +674,7 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat
         model_factory=mock_model_factory,
         model_instance=mock.MagicMock(spec=ModelInstance),
         llm_file_saver=mock_file_saver,
+        template_renderer=mock_template_renderer,
         http_client=http_client,
     )
     return node, mock_file_saver
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py b/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py
index 4dfec5ef60..c5a02e87e4 100644
--- a/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py
@@ -1,5 +1,14 @@
+from types import SimpleNamespace
+from unittest.mock import MagicMock
+
 from dify_graph.model_runtime.entities import ImagePromptMessageContent
-from dify_graph.nodes.question_classifier import QuestionClassifierNodeData
+from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer
+from dify_graph.nodes.protocols import HttpClientProtocol
+from dify_graph.nodes.question_classifier import (
+    QuestionClassifierNode,
+    QuestionClassifierNodeData,
+)
+from tests.workflow_test_utils import build_test_graph_init_params
 
 
 def test_init_question_classifier_node_data():
@@ -65,3 +74,52 @@ def test_init_question_classifier_node_data_without_vision_config():
     assert node_data.vision.enabled == False
     assert node_data.vision.configs.variable_selector == ["sys", "files"]
     assert node_data.vision.configs.detail == ImagePromptMessageContent.DETAIL.HIGH
+
+
+def test_question_classifier_calculate_rest_token_uses_shared_prompt_builder(monkeypatch):
+    node_data = QuestionClassifierNodeData.model_validate(
+        {
+            "title": "test classifier node",
+            "query_variable_selector": ["id", "name"],
+            "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "completion", "completion_params": {}},
+            "classes": [{"id": "1", "name": "class 1"}],
+            "instruction": "This is a test instruction",
+        }
+    )
+    template_renderer = MagicMock(spec=TemplateRenderer)
+    node = QuestionClassifierNode(
+        id="node-id",
+        config={"id": "node-id", "data": node_data.model_dump(mode="json")},
+        graph_init_params=build_test_graph_init_params(
+            workflow_id="workflow-id",
+            graph_config={},
+            tenant_id="tenant-id",
+            app_id="app-id",
+            user_id="user-id",
+        ),
+        graph_runtime_state=SimpleNamespace(variable_pool=MagicMock()),
+        credentials_provider=MagicMock(spec=CredentialsProvider),
+        model_factory=MagicMock(spec=ModelFactory),
+        model_instance=MagicMock(),
+        http_client=MagicMock(spec=HttpClientProtocol),
+        llm_file_saver=MagicMock(),
+        template_renderer=template_renderer,
+    )
+    fetch_prompt_messages = MagicMock(return_value=([], None))
+    monkeypatch.setattr(
+        "dify_graph.nodes.question_classifier.question_classifier_node.llm_utils.fetch_prompt_messages",
+        fetch_prompt_messages,
+    )
+    monkeypatch.setattr(
+        "dify_graph.nodes.question_classifier.question_classifier_node.llm_utils.fetch_model_schema",
+        MagicMock(return_value=SimpleNamespace(model_properties={}, parameter_rules=[])),
+    )
+
+    node._calculate_rest_token(
+        node_data=node_data,
+        query="hello",
+        model_instance=MagicMock(stop=(), parameters={}),
+        context="",
+    )
+
+    assert fetch_prompt_messages.call_args.kwargs["template_renderer"] is template_renderer
diff --git a/api/tests/unit_tests/core/workflow/test_node_factory.py b/api/tests/unit_tests/core/workflow/test_node_factory.py
index ab46126ca6..367e3958ad 100644
--- a/api/tests/unit_tests/core/workflow/test_node_factory.py
+++ b/api/tests/unit_tests/core/workflow/test_node_factory.py
@@ -140,6 +140,29 @@ class TestDefaultWorkflowCodeExecutor:
         assert executor.is_execution_error(RuntimeError("boom")) is False
 
 
+class TestDefaultLLMTemplateRenderer:
+    def test_render_jinja2_delegates_to_code_executor(self, monkeypatch):
+        renderer = node_factory.DefaultLLMTemplateRenderer()
+        execute_workflow_code_template = MagicMock(return_value={"result": "hello world"})
+        monkeypatch.setattr(
+            node_factory.CodeExecutor,
+            "execute_workflow_code_template",
+            execute_workflow_code_template,
+        )
+
+        result = renderer.render_jinja2(
+            template="Hello {{ name }}",
+            inputs={"name": "world"},
+        )
+
+        assert result == "hello world"
+        execute_workflow_code_template.assert_called_once_with(
+            language=CodeLanguage.JINJA2,
+            code="Hello {{ name }}",
+            inputs={"name": "world"},
+        )
+
+
 class TestDifyNodeFactoryInit:
     def test_init_builds_default_dependencies(self):
         graph_init_params = SimpleNamespace(run_context={"context": "value"})
@@ -150,6 +173,7 @@ class TestDifyNodeFactoryInit:
         http_request_config = sentinel.http_request_config
         credentials_provider = sentinel.credentials_provider
         model_factory = sentinel.model_factory
+        llm_template_renderer = sentinel.llm_template_renderer
 
         with (
             patch.object(
@@ -172,6 +196,11 @@ class TestDifyNodeFactoryInit:
                 "build_http_request_config",
                 return_value=http_request_config,
             ),
+            patch.object(
+                node_factory,
+                "DefaultLLMTemplateRenderer",
+                return_value=llm_template_renderer,
+            ) as llm_renderer_factory,
             patch.object(
                 node_factory,
                 "build_dify_model_access",
@@ -186,11 +215,14 @@ class TestDifyNodeFactoryInit:
         resolve_dify_context.assert_called_once_with(graph_init_params.run_context)
         build_dify_model_access.assert_called_once_with("tenant-id")
         renderer_factory.assert_called_once()
+        llm_renderer_factory.assert_called_once()
         assert renderer_factory.call_args.kwargs["code_executor"] is factory._code_executor
         assert factory.graph_init_params is graph_init_params
         assert factory.graph_runtime_state is graph_runtime_state
         assert factory._dify_context is dify_context
         assert factory._template_renderer is template_renderer
+
+        assert factory._llm_template_renderer is llm_template_renderer
         assert factory._document_extractor_unstructured_api_config is unstructured_api_config
         assert factory._http_request_config is http_request_config
         assert factory._llm_credentials_provider is credentials_provider
@@ -242,6 +274,7 @@ class TestDifyNodeFactoryCreateNode:
         factory._code_executor = sentinel.code_executor
         factory._code_limits = sentinel.code_limits
         factory._template_renderer = sentinel.template_renderer
+        factory._llm_template_renderer = sentinel.llm_template_renderer
         factory._template_transform_max_output_length = 2048
         factory._http_request_http_client = sentinel.http_client
         factory._http_request_tool_file_manager_factory = sentinel.tool_file_manager_factory
@@ -378,8 +411,22 @@ class TestDifyNodeFactoryCreateNode:
     @pytest.mark.parametrize(
         ("node_type", "constructor_name", "expected_extra_kwargs"),
         [
-            (BuiltinNodeTypes.LLM, "LLMNode", {"http_client": sentinel.http_client}),
-            (BuiltinNodeTypes.QUESTION_CLASSIFIER, "QuestionClassifierNode", {"http_client": sentinel.http_client}),
+            (
+                BuiltinNodeTypes.LLM,
+                "LLMNode",
+                {
+                    "http_client": sentinel.http_client,
+                    "template_renderer": sentinel.llm_template_renderer,
+                },
+            ),
+            (
+                BuiltinNodeTypes.QUESTION_CLASSIFIER,
+                "QuestionClassifierNode",
+                {
+                    "http_client": sentinel.http_client,
+                    "template_renderer": sentinel.llm_template_renderer,
+                },
+            ),
             (BuiltinNodeTypes.PARAMETER_EXTRACTOR, "ParameterExtractorNode", {}),
         ],
     )